Commit 1f2c15f2 authored by Sam Gross's avatar Sam Gross Committed by Soumith Chintala
Browse files

Add support for accimage.Image (#153)

It can be enabled by calling torchvision.set_image_backend('accimage')
parent 323f5294
......@@ -4,6 +4,7 @@ from timeit import default_timer as timer
from tqdm import tqdm
import torch
import torch.utils.data
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
......@@ -15,11 +16,17 @@ parser.add_argument('--nThreads', '-j', default=2, type=int, metavar='N',
help='number of data loading threads (default: 2)')
parser.add_argument('--batchSize', '-b', default=256, type=int, metavar='N',
help='mini-batch size (1 = pure stochastic) Default: 256')
parser.add_argument('--accimage', action='store_true',
help='use accimage')
if __name__ == "__main__":
args = parser.parse_args()
if args.accimage:
torchvision.set_image_backend('accimage')
print('Using {}'.format(torchvision.get_image_backend()))
# Data loading code
transform = transforms.Compose([
transforms.RandomSizedCrop(224),
......@@ -38,11 +45,13 @@ if __name__ == "__main__":
train_iter = iter(train_loader)
start_time = timer()
batch_count = 100 * args.nThreads
for i in tqdm(xrange(batch_count)):
batch_count = 20 * args.nThreads
for _ in tqdm(range(batch_count)):
batch = next(train_iter)
end_time = timer()
print("Performance: {dataset:.0f} minutes/dataset, {batch:.2f} secs/batch, {image:.2f} ms/image".format(
dataset=(end_time - start_time) * (float(len(train_loader)) / batch_count / 60.0),
batch=(end_time - start_time) / float(batch_count),
image=(end_time - start_time) / (batch_count * args.batchSize) * 1.0e+3))
print("Performance: {dataset:.0f} minutes/dataset, {batch:.1f} ms/batch,"
" {image:.2f} ms/image {rate:.0f} images/sec"
.format(dataset=(end_time - start_time) * (float(len(train_loader)) / batch_count / 60.0),
batch=(end_time - start_time) / float(batch_count) * 1.0e+3,
image=(end_time - start_time) / (batch_count * args.batchSize) * 1.0e+3,
rate=(batch_count * args.batchSize) / (end_time - start_time)))
......@@ -3,6 +3,14 @@ import torchvision.transforms as transforms
import unittest
import random
import numpy as np
from PIL import Image
try:
import accimage
except ImportError:
accimage = None
GRACE_HOPPER = 'assets/grace_hopper_517x606.jpg'
class Tester(unittest.TestCase):
......@@ -153,6 +161,45 @@ class Tester(unittest.TestCase):
expected_output = ndarray.transpose((2, 0, 1)) / 255.0
assert np.allclose(output.numpy(), expected_output)
@unittest.skipIf(accimage is None, 'accimage not available')
def test_accimage_to_tensor(self):
trans = transforms.ToTensor()
expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
output = trans(accimage.Image(GRACE_HOPPER))
self.assertEqual(expected_output.size(), output.size())
assert np.allclose(output.numpy(), expected_output.numpy())
@unittest.skipIf(accimage is None, 'accimage not available')
def test_accimage_resize(self):
trans = transforms.Compose([
transforms.Scale(256, interpolation=Image.LINEAR),
transforms.ToTensor(),
])
expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
output = trans(accimage.Image(GRACE_HOPPER))
self.assertEqual(expected_output.size(), output.size())
self.assertLess(np.abs((expected_output - output).mean()), 1e-3)
self.assertLess((expected_output - output).var(), 1e-5)
# note the high absolute tolerance
assert np.allclose(output.numpy(), expected_output.numpy(), atol=5e-2)
@unittest.skipIf(accimage is None, 'accimage not available')
def test_accimage_crop(self):
trans = transforms.Compose([
transforms.CenterCrop(256),
transforms.ToTensor(),
])
expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
output = trans(accimage.Image(GRACE_HOPPER))
self.assertEqual(expected_output.size(), output.size())
assert np.allclose(output.numpy(), expected_output.numpy())
def test_tensor_to_pil_image(self):
trans = transforms.ToPILImage()
to_tensor = transforms.ToTensor()
......
......@@ -2,3 +2,31 @@ from torchvision import models
from torchvision import datasets
from torchvision import transforms
from torchvision import utils
_image_backend = 'PIL'
def set_image_backend(backend):
"""
Specifies the package used to load images.
Options are 'PIL' and 'accimage'. The :mod:`accimage` package uses the
Intel IPP library. It is generally faster than PIL, but does not support as
many operations.
Args:
backend (string): name of the image backend
"""
global _image_backend
if backend not in ['PIL', 'accimage']:
raise ValueError("Invalid backend '{}'. Options are 'PIL' and 'accimage'"
.format(backend))
_image_backend = backend
def get_image_backend():
"""
Gets the name of the package used to load images
"""
return _image_backend
......@@ -38,10 +38,27 @@ def make_dataset(dir, class_to_idx):
return images
def default_loader(path):
def pil_loader(path):
return Image.open(path).convert('RGB')
def accimage_loader(path):
import accimage
try:
return accimage.Image(path)
except IOError:
# Potentially a decoding problem, fall back to PIL.Image
return pil_loader(path)
def default_loader(path):
from torchvision import get_image_backend
if get_image_backend() == 'accimage':
return accimage_loader(path)
else:
return pil_loader(path)
class ImageFolder(data.Dataset):
def __init__(self, root, transform=None, target_transform=None,
......
......@@ -3,6 +3,10 @@ import torch
import math
import random
from PIL import Image, ImageOps
try:
import accimage
except ImportError:
accimage = None
import numpy as np
import numbers
import types
......@@ -42,6 +46,12 @@ class ToTensor(object):
img = torch.from_numpy(pic.transpose((2, 0, 1)))
# backard compability
return img.float().div(255)
if accimage is not None and isinstance(pic, accimage.Image):
nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
pic.copyto(nppic)
return torch.from_numpy(nppic)
# handle PIL Image
if pic.mode == 'I':
img = torch.from_numpy(np.array(pic, np.int32, copy=False))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment