Commit 1f2c15f2 authored by Sam Gross's avatar Sam Gross Committed by Soumith Chintala
Browse files

Add support for accimage.Image (#153)

It can be enabled by calling torchvision.set_image_backend('accimage')
parent 323f5294
...@@ -4,6 +4,7 @@ from timeit import default_timer as timer ...@@ -4,6 +4,7 @@ from timeit import default_timer as timer
from tqdm import tqdm from tqdm import tqdm
import torch import torch
import torch.utils.data import torch.utils.data
import torchvision
import torchvision.transforms as transforms import torchvision.transforms as transforms
import torchvision.datasets as datasets import torchvision.datasets as datasets
...@@ -15,11 +16,17 @@ parser.add_argument('--nThreads', '-j', default=2, type=int, metavar='N', ...@@ -15,11 +16,17 @@ parser.add_argument('--nThreads', '-j', default=2, type=int, metavar='N',
help='number of data loading threads (default: 2)') help='number of data loading threads (default: 2)')
parser.add_argument('--batchSize', '-b', default=256, type=int, metavar='N', parser.add_argument('--batchSize', '-b', default=256, type=int, metavar='N',
help='mini-batch size (1 = pure stochastic) Default: 256') help='mini-batch size (1 = pure stochastic) Default: 256')
parser.add_argument('--accimage', action='store_true',
help='use accimage')
if __name__ == "__main__": if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
if args.accimage:
torchvision.set_image_backend('accimage')
print('Using {}'.format(torchvision.get_image_backend()))
# Data loading code # Data loading code
transform = transforms.Compose([ transform = transforms.Compose([
transforms.RandomSizedCrop(224), transforms.RandomSizedCrop(224),
...@@ -38,11 +45,13 @@ if __name__ == "__main__": ...@@ -38,11 +45,13 @@ if __name__ == "__main__":
train_iter = iter(train_loader) train_iter = iter(train_loader)
start_time = timer() start_time = timer()
batch_count = 100 * args.nThreads batch_count = 20 * args.nThreads
for i in tqdm(xrange(batch_count)): for _ in tqdm(range(batch_count)):
batch = next(train_iter) batch = next(train_iter)
end_time = timer() end_time = timer()
print("Performance: {dataset:.0f} minutes/dataset, {batch:.2f} secs/batch, {image:.2f} ms/image".format( print("Performance: {dataset:.0f} minutes/dataset, {batch:.1f} ms/batch,"
dataset=(end_time - start_time) * (float(len(train_loader)) / batch_count / 60.0), " {image:.2f} ms/image {rate:.0f} images/sec"
batch=(end_time - start_time) / float(batch_count), .format(dataset=(end_time - start_time) * (float(len(train_loader)) / batch_count / 60.0),
image=(end_time - start_time) / (batch_count * args.batchSize) * 1.0e+3)) batch=(end_time - start_time) / float(batch_count) * 1.0e+3,
image=(end_time - start_time) / (batch_count * args.batchSize) * 1.0e+3,
rate=(batch_count * args.batchSize) / (end_time - start_time)))
...@@ -3,6 +3,14 @@ import torchvision.transforms as transforms ...@@ -3,6 +3,14 @@ import torchvision.transforms as transforms
import unittest import unittest
import random import random
import numpy as np import numpy as np
from PIL import Image
try:
import accimage
except ImportError:
accimage = None
GRACE_HOPPER = 'assets/grace_hopper_517x606.jpg'
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
...@@ -153,6 +161,45 @@ class Tester(unittest.TestCase): ...@@ -153,6 +161,45 @@ class Tester(unittest.TestCase):
expected_output = ndarray.transpose((2, 0, 1)) / 255.0 expected_output = ndarray.transpose((2, 0, 1)) / 255.0
assert np.allclose(output.numpy(), expected_output) assert np.allclose(output.numpy(), expected_output)
@unittest.skipIf(accimage is None, 'accimage not available')
def test_accimage_to_tensor(self):
    """ToTensor must produce identical tensors for accimage and PIL inputs."""
    to_tensor = transforms.ToTensor()
    # Reference result decoded through the PIL code path.
    reference = to_tensor(Image.open(GRACE_HOPPER).convert('RGB'))
    # Same file decoded through the accimage code path.
    result = to_tensor(accimage.Image(GRACE_HOPPER))
    self.assertEqual(reference.size(), result.size())
    assert np.allclose(result.numpy(), reference.numpy())
@unittest.skipIf(accimage is None, 'accimage not available')
def test_accimage_resize(self):
    """Scaling via accimage must closely match the PIL bilinear resize."""
    pipeline = transforms.Compose([
        transforms.Scale(256, interpolation=Image.LINEAR),
        transforms.ToTensor(),
    ])
    # Run the identical pipeline over both decoders of the same file.
    reference = pipeline(Image.open(GRACE_HOPPER).convert('RGB'))
    result = pipeline(accimage.Image(GRACE_HOPPER))
    self.assertEqual(reference.size(), result.size())
    # Resize kernels differ slightly, so compare statistics first ...
    difference = reference - result
    self.assertLess(np.abs(difference.mean()), 1e-3)
    self.assertLess(difference.var(), 1e-5)
    # ... then element-wise; note the high absolute tolerance.
    assert np.allclose(result.numpy(), reference.numpy(), atol=5e-2)
@unittest.skipIf(accimage is None, 'accimage not available')
def test_accimage_crop(self):
    """Center-cropping must be byte-exact across the two image backends."""
    pipeline = transforms.Compose([
        transforms.CenterCrop(256),
        transforms.ToTensor(),
    ])
    # Cropping copies pixels without resampling, so results must match exactly.
    reference = pipeline(Image.open(GRACE_HOPPER).convert('RGB'))
    result = pipeline(accimage.Image(GRACE_HOPPER))
    self.assertEqual(reference.size(), result.size())
    assert np.allclose(result.numpy(), reference.numpy())
def test_tensor_to_pil_image(self): def test_tensor_to_pil_image(self):
trans = transforms.ToPILImage() trans = transforms.ToPILImage()
to_tensor = transforms.ToTensor() to_tensor = transforms.ToTensor()
......
...@@ -2,3 +2,31 @@ from torchvision import models ...@@ -2,3 +2,31 @@ from torchvision import models
from torchvision import datasets from torchvision import datasets
from torchvision import transforms from torchvision import transforms
from torchvision import utils from torchvision import utils
# Name of the package currently used to decode images ('PIL' or 'accimage').
_image_backend = 'PIL'


def set_image_backend(backend):
    """
    Specifies the package used to load images.

    Options are 'PIL' and 'accimage'. The :mod:`accimage` package uses the
    Intel IPP library. It is generally faster than PIL, but does not support as
    many operations.

    Args:
        backend (string): name of the image backend

    Raises:
        ValueError: if *backend* is neither 'PIL' nor 'accimage'.
    """
    global _image_backend
    # Reject unknown names up front so later loads never hit a bad backend.
    if backend not in ('PIL', 'accimage'):
        raise ValueError("Invalid backend '{}'. Options are 'PIL' and 'accimage'"
                         .format(backend))
    _image_backend = backend


def get_image_backend():
    """
    Gets the name of the package used to load images
    """
    return _image_backend
...@@ -38,10 +38,27 @@ def make_dataset(dir, class_to_idx): ...@@ -38,10 +38,27 @@ def make_dataset(dir, class_to_idx):
return images return images
def pil_loader(path):
    """Open the image at *path* with PIL and normalize it to RGB mode."""
    return Image.open(path).convert('RGB')


def accimage_loader(path):
    """Open the image at *path* with accimage, falling back to PIL on error."""
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_loader(path):
    """Load *path* with whichever backend torchvision is configured to use."""
    # Imported lazily to avoid a circular import at module load time.
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)
class ImageFolder(data.Dataset): class ImageFolder(data.Dataset):
def __init__(self, root, transform=None, target_transform=None, def __init__(self, root, transform=None, target_transform=None,
......
...@@ -3,6 +3,10 @@ import torch ...@@ -3,6 +3,10 @@ import torch
import math import math
import random import random
from PIL import Image, ImageOps from PIL import Image, ImageOps
try:
import accimage
except ImportError:
accimage = None
import numpy as np import numpy as np
import numbers import numbers
import types import types
...@@ -42,6 +46,12 @@ class ToTensor(object): ...@@ -42,6 +46,12 @@ class ToTensor(object):
img = torch.from_numpy(pic.transpose((2, 0, 1))) img = torch.from_numpy(pic.transpose((2, 0, 1)))
# backward compatibility
return img.float().div(255) return img.float().div(255)
if accimage is not None and isinstance(pic, accimage.Image):
nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
pic.copyto(nppic)
return torch.from_numpy(nppic)
# handle PIL Image # handle PIL Image
if pic.mode == 'I': if pic.mode == 'I':
img = torch.from_numpy(np.array(pic, np.int32, copy=False)) img = torch.from_numpy(np.array(pic, np.int32, copy=False))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment