transforms.py 18.6 KB
Newer Older
1
from __future__ import division
soumith's avatar
soumith committed
2
3
4
import torch
import math
import random
5
from PIL import Image, ImageOps
6
7
8
9
try:
    import accimage
except ImportError:
    accimage = None
10
import numpy as np
11
import numbers
Soumith Chintala's avatar
Soumith Chintala committed
12
import types
13
import collections
import collections.abc
soumith's avatar
soumith committed
14

15

Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def _is_pil_image(img):
    """Return True if ``img`` is a PIL image (or an accimage image when available)."""
    # accimage is optional: only include its Image type if the import succeeded.
    image_types = (Image.Image,) if accimage is None else (Image.Image, accimage.Image)
    return isinstance(img, image_types)


def _is_tensor_image(img):
    return torch.is_tensor(img) and img.ndimension() == 3


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
31
def to_tensor(pic):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.

    See ``ToTensor`` for more details.

    Args:
        pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.
            A numpy array may be ``H x W`` (treated as one channel) or
            ``H x W x C``.

    Returns:
        Tensor: Converted image laid out as ``C x H x W``. Numpy input and
        8-bit PIL input are returned as floats divided by 255; PIL images in
        mode ``'I'``, ``'I;16'`` or ``'F'`` keep their integer/float values.

    Raises:
        TypeError: If ``pic`` is neither a PIL image nor a 2-D/3-D ndarray.
    """
    if not (_is_pil_image(pic) or _is_numpy_image(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

    if isinstance(pic, np.ndarray):
        # handle numpy array
        if pic.ndim == 2:
            # _is_numpy_image accepts H x W arrays; give them a channel axis so
            # the HWC -> CHW transpose below is valid.
            pic = pic[:, :, None]
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility
        return img.float().div(255)

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image
    # Choose a numpy dtype matching how PIL stores the mode; every other mode
    # is read as 8-bit samples. Going through numpy (instead of tobytes())
    # also handles mode 'F' (4 bytes/pixel) and mode '1' (bit-packed bytes).
    if pic.mode == 'I':
        dtype = np.int32
    elif pic.mode == 'I;16':
        dtype = np.int16
    elif pic.mode == 'F':
        dtype = np.float32
    else:
        dtype = np.uint8
    # copy=True: torch.from_numpy wants a writable buffer, and NumPy >= 2
    # raises when copy=False cannot be honoured.
    img = torch.from_numpy(np.array(pic, dtype, copy=True))
    if pic.mode == '1':
        # bilevel images arrive as 0/1; stretch them to the 0/255 byte range
        img = 255 * img

    # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    return img


80
def to_pil_image(pic):
    """Convert a tensor or an ndarray to PIL Image.

    See ``ToPILImage`` for more details.

    Args:
        pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.
            Tensors are expected as ``C x H x W``; arrays as ``H x W x C`` or
            ``H x W`` (one channel). A ``torch.FloatTensor`` is multiplied by
            255 and cast to bytes first.

    Returns:
        PIL.Image: Image converted to PIL.Image.

    Raises:
        TypeError: If ``pic`` is not a 3-D tensor or a 2-D/3-D ndarray, or if
            its dtype / channel count has no matching PIL mode.
    """
    if not (_is_numpy_image(pic) or _is_tensor_image(pic)):
        raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))

    npimg = pic
    mode = None
    if isinstance(pic, torch.FloatTensor):
        pic = pic.mul(255).byte()
    if torch.is_tensor(pic):
        # CHW tensor -> HWC array, the layout PIL expects
        npimg = np.transpose(pic.numpy(), (1, 2, 0))

    if npimg.ndim == 2:
        # _is_numpy_image accepts H x W arrays; treat them as a single channel
        npimg = npimg[:, :, np.newaxis]

    channels = npimg.shape[2]
    if channels == 1:
        npimg = npimg[:, :, 0]
        if npimg.dtype == np.uint8:
            mode = 'L'
        elif npimg.dtype == np.int16:
            mode = 'I;16'
        elif npimg.dtype == np.int32:
            mode = 'I'
        elif npimg.dtype == np.float32:
            mode = 'F'
    elif channels == 3:
        if npimg.dtype == np.uint8:
            mode = 'RGB'
    elif channels == 4:
        if npimg.dtype == np.uint8:
            mode = 'RGBA'

    if mode is None:
        raise TypeError('Input of dtype {} with {} channel(s) is not supported'.format(
            npimg.dtype, channels))
    return Image.fromarray(npimg, mode=mode)


def normalize(tensor, mean, std):
    """Normalize a tensor image with mean and standard deviation.

    Each channel is transformed in place as ``(channel - mean) / std`` and the
    same tensor object is returned. See ``Normalize`` for more details.

    Args:
        tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        mean (sequence): Sequence of means, one per channel (e.g. R, G, B).
        std (sequence): Sequence of standard deviations, one per channel.

    Returns:
        Tensor: Normalized image.

    Raises:
        TypeError: If ``tensor`` is not a 3-D torch tensor.
    """
    if _is_tensor_image(tensor):
        # TODO: make efficient
        for plane, mu, sigma in zip(tensor, mean, std):
            plane.sub_(mu)
            plane.div_(sigma)
        return tensor
    raise TypeError('tensor is not a torch image.')


def scale(img, size, interpolation=Image.BILINEAR):
    """Rescale the input PIL.Image to the given size.

    Args:
        img (PIL.Image): Image to be scaled.
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``

    Returns:
        PIL.Image: Rescaled image.

    Raises:
        TypeError: If ``img`` is not a PIL image or ``size`` is neither an int
            nor a 2-element iterable.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    # The ABC aliases were removed from ``collections`` in Python 3.10; the
    # canonical location is ``collections.abc``.
    if not (isinstance(size, int) or
            (isinstance(size, collections.abc.Iterable) and len(size) == 2)):
        raise TypeError('Got inappropriate size arg: {}'.format(size))

    if isinstance(size, int):
        w, h = img.size
        # already at the requested short-edge length: nothing to do
        if (w <= h and w == size) or (h <= w and h == size):
            return img
        if w < h:
            ow = size
            oh = int(size * h / w)
        else:
            oh = size
            ow = int(size * w / h)
        return img.resize((ow, oh), interpolation)

    # ``size`` is (h, w) while PIL expects (w, h); unpacking (rather than
    # slicing) works for any 2-element iterable.
    oh, ow = size
    return img.resize((ow, oh), interpolation)
Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
179
180
181


def pad(img, padding, fill=0):
    """Pad the given PIL.Image on all sides with the given "pad" value.

    Args:
        img (PIL.Image): Image to be padded.
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.

    Returns:
        PIL.Image: Padded image.

    Raises:
        TypeError: If ``img``, ``padding`` or ``fill`` has an unsupported type.
        ValueError: If ``padding`` is a tuple whose length is not 2 or 4.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if not isinstance(padding, (numbers.Number, tuple)):
        raise TypeError('Got inappropriate padding arg')
    if not isinstance(fill, (numbers.Number, str, tuple)):
        raise TypeError('Got inappropriate fill arg')

    # ``padding`` is now a number or a tuple, and only tuples have a length to
    # validate. (``collections.Sequence`` no longer exists on Python >= 3.10.)
    if isinstance(padding, tuple) and len(padding) not in [2, 4]:
        raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                         "{} element tuple".format(len(padding)))

    return ImageOps.expand(img, border=padding, fill=fill)


212
def crop(img, i, j, h, w):
    """Crop a rectangular region out of the given PIL.Image.

    Args:
        img (PIL.Image): Image to be cropped.
        i: Row (y) coordinate of the upper edge of the crop.
        j: Column (x) coordinate of the left edge of the crop.
        h: Height of the cropped image.
        w: Width of the cropped image.

    Returns:
        PIL.Image: Cropped image.

    Raises:
        TypeError: If ``img`` is not a PIL image.
    """
    if _is_pil_image(img):
        top, left = i, j
        # PIL boxes are (left, upper, right, lower)
        box = (left, top, left + w, top + h)
        return img.crop(box)
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
229
230


231
def scaled_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
    """Crop the given PIL.Image and scale it to desired size.

    Notably used in RandomSizedCrop.

    Args:
        img (PIL.Image): Image to be cropped.
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.
        size (sequence or int): Desired output size. Same semantics as ``scale``.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``.

    Returns:
        PIL.Image: Cropped image.

    Raises:
        TypeError: If ``img`` is not a PIL image.
    """
    # Raise (rather than assert) like the sibling helpers, so the check is not
    # stripped when Python runs with -O.
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    img = crop(img, i, j, h, w)
    return scale(img, size, interpolation)
Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
252
253
254


def hflip(img):
    """Mirror the given PIL.Image left-to-right.

    Args:
        img (PIL.Image): Image to be flipped.

    Returns:
        PIL.Image: Horizontally flipped image.

    Raises:
        TypeError: If ``img`` is not a PIL image.
    """
    if _is_pil_image(img):
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))


soumith's avatar
soumith committed
269
class Compose(object):
    """Chain several transforms so they run one after another.

    Args:
        transforms (list of ``Transform`` objects): transforms to apply, in
            order; each receives the previous one's output.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result


class ToTensor(object):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.

    An image laid out as (H x W x C) with values in [0, 255] becomes a
    torch.FloatTensor laid out as (C x H x W) with values in [0.0, 1.0].
    The work is delegated to ``to_tensor``.
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        converted = to_tensor(pic)
        return converted
307

Adam Paszke's avatar
Adam Paszke committed
308

309
class ToPILImage(object):
    """Convert a tensor or an ndarray to PIL Image.

    Accepts a torch.*Tensor shaped C x H x W or a numpy ndarray shaped
    H x W x C. The work is delegated to ``to_pil_image``, which multiplies a
    ``torch.FloatTensor`` by 255 and casts it to bytes before conversion.
    """

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.

        Returns:
            PIL.Image: Image converted to PIL.Image.
        """
        converted = to_pil_image(pic)
        return converted
326

soumith's avatar
soumith committed
327
328

class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.

    With one mean and one std per channel (e.g. R, G, B), every channel of
    the torch.*Tensor is updated as ``channel = (channel - mean) / std``.
    The work is delegated to ``normalize``.

    Args:
        mean (sequence): Sequence of means, one per channel.
        std (sequence): Sequence of standard deviations, one per channel.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized image.
        """
        mean, std = self.mean, self.std
        return normalize(tensor, mean, std)
soumith's avatar
soumith committed
354
355
356


class Scale(object):
    """Rescale the input PIL.Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # ``collections.Iterable`` was removed in Python 3.10; use the ABC
        # from ``collections.abc``.
        assert isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be scaled.

        Returns:
            PIL.Image: Rescaled image.
        """
        return scale(img, self.size, self.interpolation)
soumith's avatar
soumith committed
383
384
385


class CenterCrop(object):
    """Crop the central region of the given PIL.Image.

    Args:
        size (sequence or int): Desired output size of the crop as (h, w).
            A single number produces a square (size, size) crop.
    """

    def __init__(self, size):
        self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size

    @staticmethod
    def get_params(img, output_size):
        """Compute the ``crop`` arguments for a centered crop.

        Args:
            img (PIL.Image): Image to be cropped.
            output_size (tuple): Expected output size (h, w) of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for center crop.
        """
        width, height = img.size
        crop_h, crop_w = output_size
        top = int(round((height - crop_h) / 2.))
        left = int(round((width - crop_w) / 2.))
        return top, left, crop_h, crop_w

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        return crop(img, *self.get_params(img, self.size))
soumith's avatar
soumith committed
427
428


429
class Pad(object):
    """Pad the given PIL.Image on all sides with the given "pad" value.

    Args:
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
    """

    def __init__(self, padding, fill=0):
        assert isinstance(padding, (numbers.Number, tuple))
        assert isinstance(fill, (numbers.Number, str, tuple))
        # ``padding`` is a number or a tuple here; only tuples carry a length
        # to validate. (``collections.Sequence`` was removed in Python 3.10.)
        if isinstance(padding, tuple) and len(padding) not in [2, 4]:
            raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                             "{} element tuple".format(len(padding)))
        self.padding = padding
        self.fill = fill

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be padded.

        Returns:
            PIL.Image: Padded image.
        """
        return pad(img, self.padding, self.fill)
461

462

Soumith Chintala's avatar
Soumith Chintala committed
463
class Lambda(object):
    """Apply a user-defined callable as a transform.

    Args:
        lambd (callable): Function (lambda, ``def`` function, builtin,
            ``functools.partial`` or any callable object) applied to the input.
    """

    def __init__(self, lambd):
        # ``types.LambdaType`` only matched Python functions and rejected
        # builtins, partials and callable objects; any callable is usable here.
        assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)

477

soumith's avatar
soumith committed
478
class RandomCrop(object):
    """Crop the given PIL.Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. A sequence of length 2
            pads left/right and top/bottom; a sequence of length 4 pads left,
            top, right, bottom borders respectively.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL.Image): Image to be cropped.
            output_size (tuple): Expected output size (h, w) of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            # The crop covers the whole image. Return crop params, not the
            # image itself, so callers can always unpack four values.
            return 0, 0, h, w

        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        padding = self.padding
        if isinstance(padding, numbers.Number):
            if padding > 0:
                img = pad(img, padding)
        else:
            # Sequence padding (documented above) cannot be compared with 0;
            # ``pad`` only accepts tuples, so normalise lists here.
            img = pad(img, tuple(padding))

        i, j, h, w = self.get_params(img, self.size)
        return crop(img, i, j, h, w)
soumith's avatar
soumith committed
532
533
534


class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL.Image randomly with a probability of 0.5."""

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.

        Returns:
            PIL.Image: Randomly flipped image.
        """
        flip = random.random() < 0.5
        return hflip(img) if flip else img


550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
class RandomVerticalFlip(object):
    """Flip the given PIL.Image top-to-bottom with a probability of 0.5."""

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.

        Returns:
            PIL.Image: Randomly flipped image.
        """
        should_flip = random.random() < 0.5
        if not should_flip:
            return img
        return img.transpose(Image.FLIP_TOP_BOTTOM)


soumith's avatar
soumith committed
566
class RandomSizedCrop(object):
    """Crop the given PIL.Image to random size and aspect ratio.

    A crop covering a random fraction (0.08 to 1.0) of the original area, with
    a random aspect ratio (width / height) between 3/4 and 4/3, is made. This
    crop is finally resized to the given size. If no valid crop is found in
    10 attempts, a central square crop is used instead.
    This is popularly used to train the Inception networks.

    Args:
        size (int or sequence): expected output size. An int gives a square
            (size, size) output; a sequence is used as (h, w).
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        if isinstance(size, numbers.Number):
            self.size = (size, size)
        else:
            # previously a sequence became ((h, w), (h, w)), which ``scale``
            # cannot handle
            self.size = size
        self.interpolation = interpolation

    @staticmethod
    def get_params(img):
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop.
        """
        width, height = img.size
        area = width * height

        for _ in range(10):
            target_area = random.uniform(0.08, 1.0) * area
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            if w <= width and h <= height:
                i = random.randint(0, height - h)
                j = random.randint(0, width - w)
                return i, j, h, w

        # Fallback: central square crop. PIL images expose ``size`` as
        # (width, height); they have no ``shape`` attribute.
        side = min(width, height)
        i = (height - side) // 2
        j = (width - side) // 2
        return i, j, side, side

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped and scaled.

        Returns:
            PIL.Image: Randomly cropped and scaled image.
        """
        i, j, h, w = self.get_params(img)
        return scaled_crop(img, i, j, h, w, self.size, self.interpolation)