Commit 7aeec57f authored by Sasank Chilamkurthy's avatar Sasank Chilamkurthy
Browse files

Asserts for functions

parent bf38166d
......@@ -13,7 +13,24 @@ import types
import collections
def _is_pil_image(img):
if accimage is not None:
return isinstance(img, (Image.Image, accimage.Image))
else:
return isinstance(img, Image.Image)
def _is_tensor_image(img):
return torch.is_tensor(img) and img.ndimension() == 3
def _is_numpy_image(img):
return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
def to_tensor(pic):
assert _is_pil_image(pic) or _is_numpy_image(pic), 'pic should be PIL Image or ndarray'
if isinstance(pic, np.ndarray):
# handle numpy array
img = torch.from_numpy(pic.transpose((2, 0, 1)))
......@@ -50,13 +67,15 @@ def to_tensor(pic):
def to_pilimage(pic):
assert _is_numpy_image(pic) or _is_tensor_image(pic), 'pic should be Tensor or ndarray'
npimg = pic
mode = None
if isinstance(pic, torch.FloatTensor):
pic = pic.mul(255).byte()
if torch.is_tensor(pic):
npimg = np.transpose(pic.numpy(), (1, 2, 0))
assert isinstance(npimg, np.ndarray), 'pic should be Tensor or ndarray'
assert isinstance(npimg, np.ndarray)
if npimg.shape[2] == 1:
npimg = npimg[:, :, 0]
......@@ -76,6 +95,7 @@ def to_pilimage(pic):
def normalize(tensor, mean, std):
assert _is_tensor_image(tensor)
# TODO: make efficient
for t, m, s in zip(tensor, mean, std):
t.sub_(m).div_(s)
......@@ -83,6 +103,7 @@ def normalize(tensor, mean, std):
def scale(img, size, interpolation=Image.BILINEAR):
assert _is_pil_image(img), 'img should be PIL Image'
assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
if isinstance(size, int):
w, h = img.size
......@@ -101,6 +122,7 @@ def scale(img, size, interpolation=Image.BILINEAR):
def pad(img, padding, fill=0):
assert _is_pil_image(img), 'img should be PIL Image'
assert isinstance(padding, (numbers.Number, tuple))
assert isinstance(fill, (numbers.Number, str, tuple))
if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
......@@ -111,15 +133,18 @@ def pad(img, padding, fill=0):
def crop(img, x, y, w, h):
assert _is_pil_image(img), 'img should be PIL Image'
return img.crop((x, y, x + w, y + h))
def scaled_crop(img, x, y, w, h, size, interpolation=Image.BILINEAR):
assert _is_pil_image(img), 'img should be PIL Image'
img = crop(img, x, y, w, h)
img = scale(img, size, interpolation)
def hflip(img):
assert _is_pil_image(img), 'img should be PIL Image'
return img.transpose(Image.FLIP_LEFT_RIGHT)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment