Commit 5a2bbc57 authored by Sasank Chilamkurthy's avatar Sasank Chilamkurthy
Browse files

First cut refactoring

(cherry picked from commit 71afec427baca8e37cd9e10d98812bc586e9a4ac)
parent 8e375670
...@@ -13,43 +13,7 @@ import types ...@@ -13,43 +13,7 @@ import types
import collections import collections
class Compose(object): def to_tensor(pic):
"""Composes several transforms together.
Args:
transforms (list of ``Transform`` objects): list of transforms to compose.
Example:
>>> transforms.Compose([
>>> transforms.CenterCrop(10),
>>> transforms.ToTensor(),
>>> ])
"""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, img):
for t in self.transforms:
img = t(img)
return img
class ToTensor(object):
"""Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.
Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
"""
def __call__(self, pic):
"""
Args:
pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.
Returns:
Tensor: Converted image.
"""
if isinstance(pic, np.ndarray): if isinstance(pic, np.ndarray):
# handle numpy array # handle numpy array
img = torch.from_numpy(pic.transpose((2, 0, 1))) img = torch.from_numpy(pic.transpose((2, 0, 1)))
...@@ -85,22 +49,7 @@ class ToTensor(object): ...@@ -85,22 +49,7 @@ class ToTensor(object):
return img return img
class ToPILImage(object): def to_pilimage(pic):
"""Convert a tensor to PIL Image.
Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
H x W x C to a PIL.Image while preserving the value range.
"""
def __call__(self, pic):
"""
Args:
pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.
Returns:
PIL.Image: Image converted to PIL.Image.
"""
npimg = pic npimg = pic
mode = None mode = None
if isinstance(pic, torch.FloatTensor): if isinstance(pic, torch.FloatTensor):
...@@ -126,6 +75,109 @@ class ToPILImage(object): ...@@ -126,6 +75,109 @@ class ToPILImage(object):
return Image.fromarray(npimg, mode=mode) return Image.fromarray(npimg, mode=mode)
def normalize(tensor, mean, std):
    """Normalize the slices of ``tensor`` in place with paired mean/std values.

    Iterating ``tensor`` yields its slices along the first dimension; each
    slice is paired positionally with one entry of ``mean`` and one entry of
    ``std``. Pairing follows ``zip`` semantics, so iteration stops at the
    shortest of the three inputs.

    Args:
        tensor: Iterable whose items support in-place ``sub_`` and ``div_``.
        mean: Iterable of values subtracted from the matching slice.
        std: Iterable of values the matching slice is divided by.

    Returns:
        The same ``tensor`` object that was passed in (modified in place).
    """
    # TODO: make efficient
    for channel, channel_mean, channel_std in zip(tensor, mean, std):
        channel.sub_(channel_mean)
        channel.div_(channel_std)
    return tensor
def scale(img, size, interpolation=Image.BILINEAR):
    """Resize ``img`` to ``size``.

    Args:
        img: Image object exposing ``size`` as ``(width, height)`` and a
            ``resize((width, height), interpolation)`` method.
        size: Either an ``int`` or a two-element iterable. An ``int`` is the
            target length of the shorter edge; the other edge is scaled by
            the same ratio and truncated to ``int``. A two-element iterable
            is passed to ``img.resize`` unchanged.
        interpolation: Resampling filter forwarded to ``img.resize``.

    Returns:
        The resized image, or ``img`` itself when ``size`` is an ``int`` and
        the shorter edge already has that length.

    Raises:
        AssertionError: If ``size`` is neither an ``int`` nor an iterable of
            length 2.
    """
    # The ``collections.Iterable`` alias was removed in Python 3.10; resolve
    # the ABC from ``collections.abc`` and fall back for Python 2.
    try:
        from collections.abc import Iterable
    except ImportError:  # Python 2
        from collections import Iterable

    assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)

    if not isinstance(size, int):
        # Explicit target dimensions: hand them to resize as-is.
        return img.resize(size, interpolation)

    w, h = img.size
    if (w <= h and w == size) or (h <= w and h == size):
        # Shorter edge already matches; nothing to do.
        return img

    if w < h:
        ow, oh = size, int(size * h / w)
    else:
        ow, oh = int(size * w / h), size
    return img.resize((ow, oh), interpolation)
def pad(img, padding, fill=0):
    """Add a border around ``img`` using ``ImageOps.expand``.

    Args:
        img: Image forwarded to ``ImageOps.expand``.
        padding: Border width; must be a number.
        fill: Border fill value; must be a number, a string, or a tuple.

    Returns:
        Whatever ``ImageOps.expand`` returns for the given arguments.

    Raises:
        AssertionError: If ``padding`` or ``fill`` has an unsupported type.
    """
    assert isinstance(padding, numbers.Number)
    assert isinstance(fill, (numbers.Number, str, tuple))
    return ImageOps.expand(img, border=padding, fill=fill)
def crop(img, x, y, w, h):
    """Crop a ``w`` by ``h`` region of ``img`` whose top-left corner is ``(x, y)``.

    Args:
        img: Object exposing ``crop(box)`` where ``box`` is a 4-tuple.
        x: Left coordinate of the region.
        y: Upper coordinate of the region.
        w: Width of the region.
        h: Height of the region.

    Returns:
        The result of ``img.crop((x, y, x + w, y + h))``.
    """
    right = x + w
    lower = y + h
    box = (x, y, right, lower)
    return img.crop(box)
def scaled_crop(img, x, y, w, h, size, interpolation=Image.BILINEAR):
    """Crop a region out of ``img`` and then resize the crop to ``size``.

    Args:
        img: Image to process.
        x: Left coordinate of the crop region.
        y: Upper coordinate of the crop region.
        w: Width of the crop region.
        h: Height of the crop region.
        size: Target size forwarded to ``scale``.
        interpolation: Resampling filter forwarded to ``scale``.

    Returns:
        The cropped and rescaled image.
    """
    cropped = crop(img, x, y, w, h)
    # The result must be returned explicitly; without this the function
    # would yield ``None`` to its callers.
    return scale(cropped, size, interpolation)
def hflip(img):
    """Flip ``img`` horizontally.

    Args:
        img: Object exposing ``transpose(method)``.

    Returns:
        The result of ``img.transpose(Image.FLIP_LEFT_RIGHT)``.
    """
    flip_method = Image.FLIP_LEFT_RIGHT
    return img.transpose(flip_method)
class Compose(object):
    """Chain several transforms into a single callable.

    Args:
        transforms (list of ``Transform`` objects): transforms to apply, in order.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        """Feed ``img`` through every transform in order and return the result."""
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result
class ToTensor(object):
    """Callable wrapper that converts a ``PIL.Image`` or ``numpy.ndarray`` to a tensor.

    Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    The conversion itself is delegated to ``to_tensor``.
    """

    def __call__(self, pic):
        """Convert ``pic`` by delegating to ``to_tensor``.

        Args:
            pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        converted = to_tensor(pic)
        return converted
class ToPILImage(object):
    """Callable wrapper that converts a tensor or ndarray to a PIL Image.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL.Image while preserving the value range.
    The conversion itself is delegated to ``to_pilimage``.
    """

    def __call__(self, pic):
        """Convert ``pic`` by delegating to ``to_pilimage``.

        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.

        Returns:
            PIL.Image: Image converted to PIL.Image.
        """
        converted = to_pilimage(pic)
        return converted
class Normalize(object): class Normalize(object):
"""Normalize an tensor image with mean and standard deviation. """Normalize an tensor image with mean and standard deviation.
...@@ -151,10 +203,7 @@ class Normalize(object): ...@@ -151,10 +203,7 @@ class Normalize(object):
Returns: Returns:
Tensor: Normalized image. Tensor: Normalized image.
""" """
# TODO: make efficient return normalize(tensor, self.mean, self.std)
for t, m, s in zip(tensor, self.mean, self.std):
t.sub_(m).div_(s)
return tensor
class Scale(object): class Scale(object):
...@@ -183,20 +232,7 @@ class Scale(object): ...@@ -183,20 +232,7 @@ class Scale(object):
Returns: Returns:
PIL.Image: Rescaled image. PIL.Image: Rescaled image.
""" """
if isinstance(self.size, int): return scale(img, self.size, self.interpolation)
w, h = img.size
if (w <= h and w == self.size) or (h <= w and h == self.size):
return img
if w < h:
ow = self.size
oh = int(self.size * h / w)
return img.resize((ow, oh), self.interpolation)
else:
oh = self.size
ow = int(self.size * w / h)
return img.resize((ow, oh), self.interpolation)
else:
return img.resize(self.size, self.interpolation)
class CenterCrop(object): class CenterCrop(object):
...@@ -214,6 +250,13 @@ class CenterCrop(object): ...@@ -214,6 +250,13 @@ class CenterCrop(object):
else: else:
self.size = size self.size = size
def get_params(self, img):
    """Compute the crop box that centers ``self.size`` inside ``img``.

    Args:
        img: Object exposing ``size`` as ``(width, height)``.

    Returns:
        tuple: ``(x1, y1, tw, th)`` where ``(x1, y1)`` is the top-left corner
        of the crop and ``(tw, th)`` its width and height. ``self.size`` is
        read as ``(height, width)``.
    """
    img_w, img_h = img.size
    crop_h, crop_w = self.size
    # Split the leftover margin evenly on both sides, rounding to an integer.
    left = int(round((img_w - crop_w) / 2.))
    top = int(round((img_h - crop_h) / 2.))
    return left, top, crop_w, crop_h
def __call__(self, img): def __call__(self, img):
""" """
Args: Args:
...@@ -222,11 +265,8 @@ class CenterCrop(object): ...@@ -222,11 +265,8 @@ class CenterCrop(object):
Returns: Returns:
PIL.Image: Cropped image. PIL.Image: Cropped image.
""" """
w, h = img.size x1, y1, tw, th = self.get_params(img)
th, tw = self.size return crop(img, x1, y1, tw, th)
x1 = int(round((w - tw) / 2.))
y1 = int(round((h - th) / 2.))
return img.crop((x1, y1, x1 + tw, y1 + th))
class Pad(object): class Pad(object):
...@@ -260,7 +300,7 @@ class Pad(object): ...@@ -260,7 +300,7 @@ class Pad(object):
Returns: Returns:
PIL.Image: Padded image. PIL.Image: Padded image.
""" """
return ImageOps.expand(img, border=self.padding, fill=self.fill) return pad(img, self.padding, self.fill)
class Lambda(object): class Lambda(object):
...@@ -298,6 +338,16 @@ class RandomCrop(object): ...@@ -298,6 +338,16 @@ class RandomCrop(object):
self.size = size self.size = size
self.padding = padding self.padding = padding
def get_params(self, img):
    """Pick a random crop box of ``self.size`` inside ``img``.

    Args:
        img: Object exposing ``size`` as ``(width, height)``.

    Returns:
        tuple: ``(x1, y1, tw, th)`` where ``(x1, y1)`` is the top-left corner
        of the crop and ``(tw, th)`` its width and height. ``self.size`` is
        read as ``(height, width)``.
    """
    w, h = img.size
    th, tw = self.size
    if w == tw and h == th:
        # The image already has the target size. Return a box covering the
        # whole image so callers can always unpack four values; returning
        # the image itself here would break that unpacking.
        return 0, 0, tw, th

    x1 = random.randint(0, w - tw)
    y1 = random.randint(0, h - th)
    return x1, y1, tw, th
def __call__(self, img): def __call__(self, img):
""" """
Args: Args:
...@@ -307,16 +357,11 @@ class RandomCrop(object): ...@@ -307,16 +357,11 @@ class RandomCrop(object):
PIL.Image: Cropped image. PIL.Image: Cropped image.
""" """
if self.padding > 0: if self.padding > 0:
img = ImageOps.expand(img, border=self.padding, fill=0) img = pad(img, self.padding)
w, h = img.size x1, y1, tw, th = self.get_params(img)
th, tw = self.size
if w == tw and h == th:
return img
x1 = random.randint(0, w - tw) return crop(img, x1, y1, tw, th)
y1 = random.randint(0, h - th)
return img.crop((x1, y1, x1 + tw, y1 + th))
class RandomHorizontalFlip(object): class RandomHorizontalFlip(object):
...@@ -331,7 +376,7 @@ class RandomHorizontalFlip(object): ...@@ -331,7 +376,7 @@ class RandomHorizontalFlip(object):
PIL.Image: Randomly flipped image. PIL.Image: Randomly flipped image.
""" """
if random.random() < 0.5: if random.random() < 0.5:
return img.transpose(Image.FLIP_LEFT_RIGHT) return hflip(img)
return img return img
...@@ -352,7 +397,7 @@ class RandomSizedCrop(object): ...@@ -352,7 +397,7 @@ class RandomSizedCrop(object):
self.size = size self.size = size
self.interpolation = interpolation self.interpolation = interpolation
def __call__(self, img): def get_params(self, img):
for attempt in range(10): for attempt in range(10):
area = img.size[0] * img.size[1] area = img.size[0] * img.size[1]
target_area = random.uniform(0.08, 1.0) * area target_area = random.uniform(0.08, 1.0) * area
...@@ -365,15 +410,16 @@ class RandomSizedCrop(object): ...@@ -365,15 +410,16 @@ class RandomSizedCrop(object):
w, h = h, w w, h = h, w
if w <= img.size[0] and h <= img.size[1]: if w <= img.size[0] and h <= img.size[1]:
x1 = random.randint(0, img.size[0] - w) x = random.randint(0, img.size[0] - w)
y1 = random.randint(0, img.size[1] - h) y = random.randint(0, img.size[1] - h)
return x, y, w, h
img = img.crop((x1, y1, x1 + w, y1 + h))
assert(img.size == (w, h))
return img.resize((self.size, self.size), self.interpolation)
# Fallback # Fallback
scale = Scale(self.size, interpolation=self.interpolation) w = min(img.size[0], img.shape[1])
crop = CenterCrop(self.size) x = (img.shape[0] - w) // 2
return crop(scale(img)) y = (img.shape[1] - w) // 2
return x, y, w, w
def __call__(self, img):
    """Crop ``img`` with the box from ``get_params`` and rescale it to ``self.size``.

    Args:
        img: Image to be cropped and resized.

    Returns:
        The value returned by ``scaled_crop`` for the chosen box.
    """
    left, top, width, height = self.get_params(img)
    return scaled_crop(
        img, left, top, width, height, self.size, self.interpolation
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment