Commit 4390b559 authored by Sasank Chilamkurthy's avatar Sasank Chilamkurthy
Browse files

Make get_params static method

parent 8b18f526
...@@ -67,7 +67,7 @@ def to_tensor(pic): ...@@ -67,7 +67,7 @@ def to_tensor(pic):
return img return img
def to_pilimage(pic): def to_pil_image(pic):
if not(_is_numpy_image(pic) or _is_tensor_image(pic)): if not(_is_numpy_image(pic) or _is_tensor_image(pic)):
raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic))) raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))
...@@ -219,7 +219,7 @@ class ToPILImage(object): ...@@ -219,7 +219,7 @@ class ToPILImage(object):
PIL.Image: Image converted to PIL.Image. PIL.Image: Image converted to PIL.Image.
""" """
return to_pilimage(pic) return to_pil_image(pic)
class Normalize(object): class Normalize(object):
...@@ -294,9 +294,10 @@ class CenterCrop(object): ...@@ -294,9 +294,10 @@ class CenterCrop(object):
else: else:
self.size = size self.size = size
def get_params(self, img): @staticmethod
def get_params(img, output_size):
w, h = img.size w, h = img.size
th, tw = self.size th, tw = output_size
x1 = int(round((w - tw) / 2.)) x1 = int(round((w - tw) / 2.))
y1 = int(round((h - th) / 2.)) y1 = int(round((h - th) / 2.))
return x1, y1, tw, th return x1, y1, tw, th
...@@ -309,7 +310,7 @@ class CenterCrop(object): ...@@ -309,7 +310,7 @@ class CenterCrop(object):
Returns: Returns:
PIL.Image: Cropped image. PIL.Image: Cropped image.
""" """
x1, y1, tw, th = self.get_params(img) x1, y1, tw, th = self.get_params(img, self.size)
return crop(img, x1, y1, tw, th) return crop(img, x1, y1, tw, th)
...@@ -382,9 +383,10 @@ class RandomCrop(object): ...@@ -382,9 +383,10 @@ class RandomCrop(object):
self.size = size self.size = size
self.padding = padding self.padding = padding
def get_params(self, img): @staticmethod
def get_params(img, output_size):
w, h = img.size w, h = img.size
th, tw = self.size th, tw = output_size
if w == tw and h == th: if w == tw and h == th:
return img return img
...@@ -403,7 +405,7 @@ class RandomCrop(object): ...@@ -403,7 +405,7 @@ class RandomCrop(object):
if self.padding > 0: if self.padding > 0:
img = pad(img, self.padding) img = pad(img, self.padding)
x1, y1, tw, th = self.get_params(img) x1, y1, tw, th = self.get_params(img, self.size)
return crop(img, x1, y1, tw, th) return crop(img, x1, y1, tw, th)
...@@ -441,7 +443,8 @@ class RandomSizedCrop(object): ...@@ -441,7 +443,8 @@ class RandomSizedCrop(object):
self.size = size self.size = size
self.interpolation = interpolation self.interpolation = interpolation
def get_params(self, img): @staticmethod
def get_params(img):
for attempt in range(10): for attempt in range(10):
area = img.size[0] * img.size[1] area = img.size[0] * img.size[1]
target_area = random.uniform(0.08, 1.0) * area target_area = random.uniform(0.08, 1.0) * area
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment