transforms.py 18.1 KB
Newer Older
1
from __future__ import division
soumith's avatar
soumith committed
2
3
4
import torch
import math
import random
5
from PIL import Image, ImageOps
6
7
8
9
try:
    import accimage
except ImportError:
    accimage = None
10
import numpy as np
11
import numbers
Soumith Chintala's avatar
Soumith Chintala committed
12
import types
13
import collections
soumith's avatar
soumith committed
14

15

Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def _is_pil_image(img):
    """Return True if ``img`` is a PIL image, or an accimage image when accimage is installed."""
    image_types = (Image.Image,) if accimage is None else (Image.Image, accimage.Image)
    return isinstance(img, image_types)


def _is_tensor_image(img):
    return torch.is_tensor(img) and img.ndimension() == 3


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
31
def to_tensor(pic):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.

    See ``ToTensor`` for more details.

    Args:
        pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.
            An ndarray is read as (H, W, C); a 2-D (H, W) ndarray is treated
            as a single-channel image.

    Returns:
        Tensor: Converted image of shape (C, H, W). ndarray input is always
        returned as float divided by 255 (regardless of its dtype); PIL input
        is divided by 255 only when it ends up as a byte tensor.

    Raises:
        TypeError: If ``pic`` is neither a PIL image nor a 2-D/3-D ndarray.
    """
    if not (_is_pil_image(pic) or _is_numpy_image(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

    if isinstance(pic, np.ndarray):
        # handle numpy array. _is_numpy_image also accepts (H, W) arrays, so
        # add an explicit channel axis before reordering HWC -> CHW.
        if pic.ndim == 2:
            pic = pic[:, :, None]
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility
        return img.float().div(255)

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image. np.array() is called without copy=False: NumPy 2
    # raises if copy=False cannot be honoured, and PIL data is always copied.
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16))
    elif pic.mode == 'F':
        # 32-bit float pixels: tobytes() would give 4 bytes per pixel and the
        # single-channel view below would not match the element count.
        img = torch.from_numpy(np.array(pic, np.float32))
    elif pic.mode == '1':
        # 1-bit pixels are bit-packed by tobytes(); expand them to 0/255 bytes
        # so the byte path below rescales them to 0.0/1.0.
        img = 255 * torch.from_numpy(np.array(pic, np.uint8))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    # pic.size is (width, height); the raw buffer is row-major H x W x C
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    else:
        return img


80
def to_pil_image(pic):
    """Convert a tensor or an ndarray to PIL Image.

    See ``ToPILImage`` for more details.

    Args:
        pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.
            A tensor is read as (C, H, W); an ndarray as (H, W, C) or (H, W).

    Returns:
        PIL.Image: Image converted to PIL.Image. Single-channel data maps to
        mode 'L' (uint8), 'I;16' (int16), 'I' (int32) or 'F' (float32);
        3- and 4-channel uint8 data map to 'RGB' and 'RGBA'.

    Raises:
        TypeError: If ``pic`` is neither a 3-D tensor nor a 2-D/3-D ndarray.
        AssertionError: If the dtype / channel-count combination has no
            matching PIL mode.
    """
    if not (_is_numpy_image(pic) or _is_tensor_image(pic)):
        raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))

    npimg = pic
    mode = None
    if isinstance(pic, torch.FloatTensor):
        # float tensors are scaled by 255 and truncated to bytes; presumably
        # they hold values in [0, 1] as produced by to_tensor — not checked here
        pic = pic.mul(255).byte()
    if torch.is_tensor(pic):
        # CHW -> HWC
        npimg = np.transpose(pic.numpy(), (1, 2, 0))
    assert isinstance(npimg, np.ndarray)

    if npimg.ndim == 2:
        # _is_numpy_image accepts (H, W) arrays; treat them as one channel
        npimg = npimg[:, :, None]

    if npimg.shape[2] == 1:
        npimg = npimg[:, :, 0]
        if npimg.dtype == np.uint8:
            mode = 'L'
        elif npimg.dtype == np.int16:
            mode = 'I;16'
        elif npimg.dtype == np.int32:
            mode = 'I'
        elif npimg.dtype == np.float32:
            mode = 'F'
    elif npimg.dtype == np.uint8:
        # Forcing 'RGB' onto a 4-channel buffer makes PIL read 3 bytes per
        # 4-byte pixel and silently garbles the image.
        if npimg.shape[2] == 3:
            mode = 'RGB'
        elif npimg.shape[2] == 4:
            mode = 'RGBA'
    assert mode is not None, '{} is not supported'.format(npimg.dtype)
    return Image.fromarray(npimg, mode=mode)


def normalize(tensor, mean, std):
    """Normalize a tensor image with per-channel mean and standard deviation.

    See ``Normalize`` for more details. The channels of ``tensor`` are
    modified in place and the same tensor object is returned.

    Args:
        tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        mean (sequence): Sequence of means for R, G, B channels respectively.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respectively.

    Returns:
        Tensor: Normalized image.

    Raises:
        TypeError: If ``tensor`` is not a 3-D torch tensor.
    """
    if not _is_tensor_image(tensor):
        raise TypeError('tensor is not a torch image.')

    # TODO: make efficient (a single broadcast op instead of a per-channel loop)
    # zip stops at the shortest input, so surplus channels / stats are skipped.
    for plane, (mu, sigma) in zip(tensor, zip(mean, std)):
        plane.sub_(mu)
        plane.div_(sigma)
    return tensor


def scale(img, size, interpolation=Image.BILINEAR):
    """Rescale the input PIL.Image to the given size.

    Args:
        img (PIL.Image): Image to be scaled.
        size (sequence or int): Desired output size. If size is a sequence like
            (w, h), output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number while
            keeping the aspect ratio, i.e. if height > width, the image will be
            rescaled to (w, h) = (size, size * height / width).
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``

    Returns:
        PIL.Image: Rescaled image (``img`` itself if its smaller edge already
        equals an int ``size``).

    Raises:
        TypeError: If ``img`` is not a PIL image, or ``size`` is neither an int
            nor an iterable of length 2.
    """
    # The ABC aliases were removed from ``collections`` in Python 3.10.
    try:
        from collections.abc import Iterable
    except ImportError:  # Python 2 keeps the ABCs directly in ``collections``
        from collections import Iterable

    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)):
        raise TypeError('Got inappropriate size arg: {}'.format(size))

    if not isinstance(size, int):
        # explicit (w, h) target
        return img.resize(size, interpolation)

    w, h = img.size
    if (w <= h and w == size) or (h <= w and h == size):
        # smaller edge already has the requested length
        return img
    if w < h:
        ow = size
        oh = int(size * h / w)
    else:
        oh = size
        ow = int(size * w / h)
    return img.resize((ow, oh), interpolation)


def pad(img, padding, fill=0):
    """Pad the given PIL.Image on all sides with the given "pad" value.

    Args:
        img (PIL.Image): Image to be padded.
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.

    Returns:
        PIL.Image: Padded image.

    Raises:
        TypeError: If ``img`` is not a PIL image, or ``padding`` / ``fill``
            has an unsupported type.
        ValueError: If ``padding`` is a tuple whose length is not 2 or 4.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if not isinstance(padding, (numbers.Number, tuple)):
        raise TypeError('Got inappropriate padding arg')
    if not isinstance(fill, (numbers.Number, str, tuple)):
        raise TypeError('Got inappropriate fill arg')

    # padding is a Number or a tuple at this point, so checking for tuple is
    # enough (and collections.Sequence no longer exists on Python 3.10+).
    if isinstance(padding, tuple) and len(padding) not in [2, 4]:
        raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                         "{} element tuple".format(len(padding)))

    return ImageOps.expand(img, border=padding, fill=fill)


def crop(img, x, y, w, h):
    """Cut a rectangular region out of the given PIL.Image.

    Args:
        img (PIL.Image): Image to be cropped.
        x: Left pixel coordinate of the region.
        y: Upper pixel coordinate of the region.
        w: Width of the cropped image.
        h: Height of the cropped image.

    Returns:
        PIL.Image: Cropped image.

    Raises:
        TypeError: If ``img`` is not a PIL image.
    """
    if _is_pil_image(img):
        # PIL crop boxes are (left, upper, right, lower)
        box = (x, y, x + w, y + h)
        return img.crop(box)
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))


def scaled_crop(img, x, y, w, h, size, interpolation=Image.BILINEAR):
    """Crop the given PIL.Image and scale it to desired size.

    Notably used in RandomSizedCrop.

    Args:
        img (PIL.Image): Image to be cropped.
        x: Left pixel coordinate.
        y: Upper pixel coordinate.
        w: Width of the cropped image.
        h: Height of the cropped image.
        size (sequence or int): Desired output size. Same semantics as ``scale``.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``.

    Returns:
        PIL.Image: Cropped and scaled image.

    Raises:
        AssertionError: If ``img`` is not a PIL image.
    """
    assert _is_pil_image(img), 'img should be PIL Image'
    img = crop(img, x, y, w, h)
    img = scale(img, size, interpolation)
    # Previously the result was computed but never returned (callers got None).
    return img


def hflip(img):
    """Mirror the given PIL.Image left to right.

    Args:
        img (PIL.Image): Image to be flipped.

    Returns:
        PIL.Image: Horizontally flipped image.

    Raises:
        TypeError: If ``img`` is not a PIL image.
    """
    if _is_pil_image(img):
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))


soumith's avatar
soumith committed
265
class Compose(object):
    """Chain several transforms into a single callable.

    The transforms are applied in list order, each one receiving the output
    of the previous one.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result


class ToTensor(object):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL.Image or numpy.ndarray (H x W x C) to a tensor of shape
    (C x H x W). 8-bit data in the range [0, 255] becomes a torch.FloatTensor
    in the range [0.0, 1.0]; see ``to_tensor`` for the exact per-mode handling.
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        tensor = to_tensor(pic)
        return tensor
303

Adam Paszke's avatar
Adam Paszke committed
304

305
class ToPILImage(object):
    """Convert a tensor or an ndarray to PIL Image.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL.Image. A ``torch.FloatTensor`` is first multiplied by
    255 and cast to bytes; other inputs keep their values.
    """

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.

        Returns:
            PIL.Image: Image converted to PIL.Image.
        """
        image = to_pil_image(pic)
        return image
322

soumith's avatar
soumith committed
323
324

class Normalize(object):
    """Normalize a tensor image with per-channel mean and standard deviation.

    Given mean ``(M1, ..., Mn)`` and std ``(S1, ..., Sn)``, each channel of the
    input torch.*Tensor is transformed as ``channel = (channel - mean) / std``.
    The input tensor is modified in place.

    Args:
        mean (sequence): Sequence of means for R, G, B channels respectively.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respectively.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized image.
        """
        mean, std = self.mean, self.std
        return normalize(tensor, mean, std)
soumith's avatar
soumith committed
350
351
352


class Scale(object):
    """Rescale the input PIL.Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (w, h), output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number while
            keeping the aspect ratio, i.e. if height > width, the image will be
            rescaled to (w, h) = (size, size * height / width).
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # The ABC aliases were removed from ``collections`` in Python 3.10.
        try:
            from collections.abc import Iterable
        except ImportError:  # Python 2 keeps the ABCs directly in ``collections``
            from collections import Iterable
        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be scaled.

        Returns:
            PIL.Image: Rescaled image.
        """
        return scale(img, self.size, self.interpolation)
soumith's avatar
soumith committed
379
380
381


class CenterCrop(object):
    """Crop the given PIL.Image around its center.

    Args:
        size (sequence or int): Desired output size of the crop as (h, w).
            If size is a number instead of a sequence, a square crop
            (size, size) is made.
    """

    def __init__(self, size):
        self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size

    @staticmethod
    def get_params(img, output_size):
        """Compute ``crop`` parameters for a centered crop.

        Args:
            img (PIL.Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop, as (h, w).

        Returns:
            tuple: params (x, y, w, h) to be passed to ``crop`` for center crop.
        """
        width, height = img.size
        crop_h, crop_w = output_size
        left = int(round((width - crop_w) / 2.))
        top = int(round((height - crop_h) / 2.))
        return left, top, crop_w, crop_h

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        params = self.get_params(img, self.size)
        return crop(img, *params)
soumith's avatar
soumith committed
423
424


425
class Pad(object):
    """Pad the given PIL.Image on all sides with the given "pad" value.

    Args:
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.

    Raises:
        AssertionError: If ``padding`` or ``fill`` has an unsupported type.
        ValueError: If ``padding`` is a tuple whose length is not 2 or 4.
    """

    def __init__(self, padding, fill=0):
        assert isinstance(padding, (numbers.Number, tuple))
        assert isinstance(fill, (numbers.Number, str, tuple))
        # padding is a Number or a tuple here, so checking for tuple suffices;
        # collections.Sequence (used previously) no longer exists on Python 3.10+.
        if isinstance(padding, tuple) and len(padding) not in [2, 4]:
            raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                             "{} element tuple".format(len(padding)))

        self.padding = padding
        self.fill = fill

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be padded.

        Returns:
            PIL.Image: Padded image.
        """
        return pad(img, self.padding, self.fill)
457

458

Soumith Chintala's avatar
Soumith Chintala committed
459
class Lambda(object):
    """Apply a user-defined function as a transform.

    Args:
        lambd (callable): Function to be used for transform. Any callable is
            accepted: lambdas, ``def`` functions, builtins, ``functools.partial``
            objects, or instances defining ``__call__``.

    Raises:
        AssertionError: If ``lambd`` is not callable.
    """

    def __init__(self, lambd):
        # types.LambdaType is just FunctionType, which wrongly rejected builtins,
        # partials and callable objects; checking callable() accepts a superset.
        assert callable(lambd), '{} object is not callable'.format(type(lambd).__name__)
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)

473

soumith's avatar
soumith committed
474
class RandomCrop(object):
    """Crop the given PIL.Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop as (h, w).
            If size is a number instead of a sequence, a square crop
            (size, size) is made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively; a sequence of length 2 pads left/right and
            top/bottom.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL.Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop, as (h, w).

        Returns:
            tuple: params (x, y, w, h) to be passed to ``crop`` for random crop.

        Raises:
            ValueError: If the crop is larger than the image (empty range
                passed to ``random.randint``).
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            # The crop covers the whole image. Return crop params (previously
            # the image itself was returned, which broke 4-way unpacking).
            return 0, 0, w, h

        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)
        return x1, y1, tw, th

    @staticmethod
    def _needs_padding(padding):
        """Return True if ``padding`` (a number or a sequence) requests any padding."""
        if isinstance(padding, numbers.Number):
            return padding > 0
        # `tuple > 0` raises TypeError on Python 3, so test the entries instead
        return any(p > 0 for p in padding)

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        if self._needs_padding(self.padding):
            padding = self.padding
            if not isinstance(padding, numbers.Number):
                # pad() only accepts an int or a tuple
                padding = tuple(padding)
            img = pad(img, padding)

        x1, y1, tw, th = self.get_params(img, self.size)
        return crop(img, x1, y1, tw, th)
soumith's avatar
soumith committed
528
529
530


class RandomHorizontalFlip(object):
    """Flip the given PIL.Image left-to-right with a probability of 0.5."""

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.

        Returns:
            PIL.Image: The flipped image, or ``img`` itself when no flip is drawn.
        """
        flip = random.random() < 0.5
        return hflip(img) if flip else img


class RandomSizedCrop(object):
    """Crop the given PIL.Image to random size and aspect ratio.

    A crop covering a random (0.08 to 1.0) fraction of the original area,
    with a random aspect ratio (w / h) between 3/4 and 4/3, is made. This crop
    is finally resized to the given size.
    This is popularly used to train the Inception networks.

    Args:
        size: Desired output size, with the same semantics as ``scale``
            (an int matches the smaller edge, a (w, h) sequence is exact).
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    @staticmethod
    def get_params(img):
        """Get parameters for ``crop`` for a random sized crop.

        Up to 10 random (area, aspect ratio) samples are tried; if none fits
        inside the image, the largest centered square crop is used instead.

        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            tuple: params (x, y, w, h) to be passed to ``crop`` for a random
                sized crop.
        """
        width, height = img.size
        area = width * height  # loop-invariant

        for _ in range(10):
            target_area = random.uniform(0.08, 1.0) * area
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            if w <= width and h <= height:
                x = random.randint(0, width - w)
                y = random.randint(0, height - h)
                return x, y, w, h

        # Fallback: largest centered square. PIL images expose .size as
        # (width, height); they have no numpy-style .shape attribute.
        w = min(width, height)
        x = (width - w) // 2
        y = (height - w) // 2
        return x, y, w, w

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped and scaled.

        Returns:
            PIL.Image: Randomly cropped and scaled image.
        """
        x, y, w, h = self.get_params(img)
        return scaled_crop(img, x, y, w, h, self.size, self.interpolation)