transforms.py 18.1 KB
Newer Older
1
from __future__ import division
soumith's avatar
soumith committed
2
3
4
import torch
import math
import random
5
from PIL import Image, ImageOps
6
7
8
9
try:
    import accimage
except ImportError:
    accimage = None
10
import numpy as np
11
import numbers
Soumith Chintala's avatar
Soumith Chintala committed
12
import types
13
import collections
soumith's avatar
soumith committed
14

15

Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def _is_pil_image(img):
    """Return True if ``img`` is a PIL image (or an accimage image, if available)."""
    accepted = (Image.Image,) if accimage is None else (Image.Image, accimage.Image)
    return isinstance(img, accepted)


def _is_tensor_image(img):
    return torch.is_tensor(img) and img.ndimension() == 3


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
31
def to_tensor(pic):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.

    See ``ToTensor`` for more details.

    Args:
        pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.

    Raises:
        TypeError: If ``pic`` is neither a PIL Image nor an ndarray.
    """
    if not(_is_pil_image(pic) or _is_numpy_image(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

    if isinstance(pic, np.ndarray):
        # handle numpy array; _is_numpy_image also accepts 2-D (H x W)
        # grayscale arrays, which previously crashed on the transpose below --
        # promote them to H x W x 1 first.
        if pic.ndim == 2:
            pic = pic[:, :, None]
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility: scale to [0, 1] like the PIL byte path
        return img.float().div(255)

    if accimage is not None and isinstance(pic, accimage.Image):
        # accimage copies directly into a float32 CHW buffer
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        # NOTE(review): PIL 'I;16' is unsigned 16-bit while np.int16 is signed,
        # so values > 32767 wrap -- kept as-is for backward compatibility.
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    # only byte images are rescaled to [0, 1]; int32/int16 images keep raw values
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    else:
        return img


80
def to_pil_image(pic):
    """Convert a tensor or an ndarray to PIL Image.

    See ``ToPILImage`` for more details.

    Args:
        pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.

    Returns:
        PIL.Image: Image converted to PIL.Image.

    Raises:
        TypeError: If ``pic`` is neither a torch image tensor nor an ndarray.
    """
    if not(_is_numpy_image(pic) or _is_tensor_image(pic)):
        raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))

    npimg = pic
    mode = None
    if isinstance(pic, torch.FloatTensor):
        # float tensors are assumed to hold values in [0, 1]
        pic = pic.mul(255).byte()
    if torch.is_tensor(pic):
        # CHW tensor -> HWC ndarray
        npimg = np.transpose(pic.numpy(), (1, 2, 0))
    assert isinstance(npimg, np.ndarray)
    # _is_numpy_image also accepts 2-D (H x W) arrays, which previously crashed
    # on the shape[2] check below -- give them an explicit channel axis.
    if npimg.ndim == 2:
        npimg = npimg[:, :, None]
    if npimg.shape[2] == 1:
        npimg = npimg[:, :, 0]

        # single-channel: pick the PIL mode that matches the dtype
        # (made a consistent elif chain; the original mixed if/elif)
        if npimg.dtype == np.uint8:
            mode = 'L'
        elif npimg.dtype == np.int16:
            mode = 'I;16'
        elif npimg.dtype == np.int32:
            mode = 'I'
        elif npimg.dtype == np.float32:
            mode = 'F'
    else:
        # multi-channel images are only supported as uint8 RGB
        if npimg.dtype == np.uint8:
            mode = 'RGB'
    assert mode is not None, '{} is not supported'.format(npimg.dtype)
    return Image.fromarray(npimg, mode=mode)


def normalize(tensor, mean, std):
    """Normalize a tensor image in place with mean and standard deviation.

    See ``Normalize`` for more details.

    Args:
        tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        mean (sequence): Sequence of means for R, G, B channels respectively.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respectively.

    Returns:
        Tensor: Normalized image (the same tensor, modified in place).
    """
    if not _is_tensor_image(tensor):
        raise TypeError('tensor is not a torch image.')
    # TODO: make efficient
    for channel, channel_mean, channel_std in zip(tensor, mean, std):
        channel.sub_(channel_mean).div_(channel_std)
    return tensor


def scale(img, size, interpolation=Image.BILINEAR):
    """Rescale the input PIL.Image to the given size.

    Args:
        img (PIL.Image): Image to be scaled.
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``

    Returns:
        PIL.Image: Rescaled image.

    Raises:
        TypeError: If ``img`` is not a PIL Image or ``size`` is invalid.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    # collections.Iterable was removed in Python 3.10; use collections.abc
    if not (isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)):
        raise TypeError('Got inappropriate size arg: {}'.format(size))

    if isinstance(size, int):
        w, h = img.size
        # short edge already matches: nothing to do
        if (w <= h and w == size) or (h <= w and h == size):
            return img
        if w < h:
            ow = size
            oh = int(size * h / w)
            return img.resize((ow, oh), interpolation)
        else:
            oh = size
            ow = int(size * w / h)
            return img.resize((ow, oh), interpolation)
    else:
        # size is (h, w) but PIL's resize expects (w, h)
        return img.resize(size[::-1], interpolation)
Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
176
177
178


def pad(img, padding, fill=0):
    """Pad the given PIL.Image on all sides with the given "pad" value.

    Args:
        img (PIL.Image): Image to be padded.
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.

    Returns:
        PIL.Image: Padded image.

    Raises:
        TypeError: If ``img``, ``padding`` or ``fill`` has the wrong type.
        ValueError: If ``padding`` is a tuple of length other than 2 or 4.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if not isinstance(padding, (numbers.Number, tuple)):
        raise TypeError('Got inappropriate padding arg')
    if not isinstance(fill, (numbers.Number, str, tuple)):
        raise TypeError('Got inappropriate fill arg')

    # collections.Sequence was removed in Python 3.10; use collections.abc
    if isinstance(padding, collections.abc.Sequence) and len(padding) not in [2, 4]:
        raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                         "{} element tuple".format(len(padding)))

    return ImageOps.expand(img, border=padding, fill=fill)


def crop(img, x, y, w, h):
    """Crop the given PIL.Image.

    Args:
        img (PIL.Image): Image to be cropped.
        x: Left pixel coordinate.
        y: Upper pixel coordinate.
        w: Width of the cropped image.
        h: Height of the cropped image.

    Returns:
        PIL.Image: Cropped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    # PIL's crop box is (left, upper, right, lower)
    box = (x, y, x + w, y + h)
    return img.crop(box)


def scaled_crop(img, x, y, w, h, size, interpolation=Image.BILINEAR):
    """Crop the given PIL.Image and scale it to desired size.

    Notably used in RandomSizedCrop.

    Args:
        img (PIL.Image): Image to be cropped.
        x: Left pixel coordinate.
        y: Upper pixel coordinate.
        w: Width of the cropped image.
        h: Height of the cropped image.
        size (sequence or int): Desired output size. Same semantics as ``scale``.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``.
    Returns:
        PIL.Image: Cropped image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    # Raise like the other functional transforms instead of asserting, so the
    # check survives `python -O` and reports the offending type.
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    img = crop(img, x, y, w, h)
    img = scale(img, size, interpolation)
    return img
Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
249
250
251


def hflip(img):
    """Horizontally flip the given PIL.Image.

    Args:
        img (PIL.Image): Image to be flipped.

    Returns:
        PIL.Image: Horizontally flipped image.
    """
    if _is_pil_image(img):
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))


soumith's avatar
soumith committed
266
class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        # Thread the image through each transform in order.
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result


class ToTensor(object):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        # Thin wrapper over the functional implementation.
        return to_tensor(pic)
304

Adam Paszke's avatar
Adam Paszke committed
305

306
class ToPILImage(object):
    """Convert a tensor or an ndarray to PIL Image.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL.Image while preserving the value range.
    """

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.

        Returns:
            PIL.Image: Image converted to PIL.Image.

        """
        # Thin wrapper over the functional implementation.
        return to_pil_image(pic)
323

soumith's avatar
soumith committed
324
325

class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.

    Given mean: (R, G, B) and std: (R, G, B),
    will normalize each channel of the torch.*Tensor, i.e.
    channel = (channel - mean) / std

    Args:
        mean (sequence): Sequence of means for R, G, B channels respectively.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respectively.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized image.
        """
        # Delegates to the functional implementation (mutates tensor in place).
        return normalize(tensor, self.mean, self.std)
soumith's avatar
soumith committed
351
352
353


class Scale(object):
    """Rescale the input PIL.Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # collections.Iterable was removed in Python 3.10; use collections.abc
        assert isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be scaled.

        Returns:
            PIL.Image: Rescaled image.
        """
        return scale(img, self.size, self.interpolation)
soumith's avatar
soumith committed
380
381
382


class CenterCrop(object):
    """Crops the given PIL.Image at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        # An int means a square crop.
        self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for center crop.

        Args:
            img (PIL.Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (x, y, w, h) to be passed to ``crop`` for center crop.
        """
        width, height = img.size
        target_h, target_w = output_size
        # Center the crop box, rounding to the nearest pixel.
        left = int(round((width - target_w) / 2.))
        top = int(round((height - target_h) / 2.))
        return left, top, target_w, target_h

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        left, top, crop_w, crop_h = self.get_params(img, self.size)
        return crop(img, left, top, crop_w, crop_h)
soumith's avatar
soumith committed
424
425


426
class Pad(object):
    """Pad the given PIL.Image on all sides with the given "pad" value.

    Args:
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
    """

    def __init__(self, padding, fill=0):
        assert isinstance(padding, (numbers.Number, tuple))
        assert isinstance(fill, (numbers.Number, str, tuple))
        # collections.Sequence was removed in Python 3.10; use collections.abc
        if isinstance(padding, collections.abc.Sequence) and len(padding) not in [2, 4]:
            raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                             "{} element tuple".format(len(padding)))

        self.padding = padding
        self.fill = fill

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be padded.

        Returns:
            PIL.Image: Padded image.
        """
        return pad(img, self.padding, self.fill)
458

459

Soumith Chintala's avatar
Soumith Chintala committed
460
class Lambda(object):
    """Apply a user-defined lambda as a transform.

    Args:
        lambd (function): Lambda/function to be used for transform.
    """

    def __init__(self, lambd):
        # Accept any callable (plain functions, builtins, functools.partial),
        # not only objects of types.LambdaType -- strictly more permissive,
        # so all previously accepted values still work.
        assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)

474

soumith's avatar
soumith committed
475
class RandomCrop(object):
    """Crop the given PIL.Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL.Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (x, y, w, h) to be passed to ``crop`` for random crop.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            # BUG FIX: this branch used to `return img`, which broke the
            # caller's 4-tuple unpacking; return the identity crop params.
            return 0, 0, tw, th

        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)
        return x1, y1, tw, th

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        if self.padding > 0:
            img = pad(img, self.padding)

        x1, y1, tw, th = self.get_params(img, self.size)

        return crop(img, x1, y1, tw, th)
soumith's avatar
soumith committed
529
530
531


class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL.Image randomly with a probability of 0.5."""

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.

        Returns:
            PIL.Image: Randomly flipped image.
        """
        # Flip with probability 0.5, otherwise return the input unchanged.
        should_flip = random.random() < 0.5
        return hflip(img) if should_flip else img


class RandomSizedCrop(object):
    """Crop the given PIL.Image to random size and aspect ratio.

    A crop of random size of (0.08 to 1.0) of the original size and a random
    aspect ratio of 3/4 to 4/3 of the original aspect ratio is made. This crop
    is finally resized to given size.
    This is popularly used to train the Inception networks.

    Args:
        size: expected output size of each edge
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = (size, size)
        self.interpolation = interpolation

    @staticmethod
    def get_params(img):
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            tuple: params (x, y, w, h) to be passed to ``crop`` for a random
                sized crop.
        """
        for attempt in range(10):
            area = img.size[0] * img.size[1]
            target_area = random.uniform(0.08, 1.0) * area
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            if w <= img.size[0] and h <= img.size[1]:
                x = random.randint(0, img.size[0] - w)
                y = random.randint(0, img.size[1] - h)
                return x, y, w, h

        # Fallback: centre square crop of the short edge.
        # BUG FIX: this branch read `img.shape`, which PIL images do not have
        # (AttributeError whenever all 10 attempts failed); use img.size.
        w = min(img.size[0], img.size[1])
        x = (img.size[0] - w) // 2
        y = (img.size[1] - w) // 2
        return x, y, w, w

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.

        Returns:
            PIL.Image: Randomly cropped and scaled image.
        """
        x, y, w, h = self.get_params(img)
        return scaled_crop(img, x, y, w, h, self.size, self.interpolation)