Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
459dc59e
Commit
459dc59e
authored
Sep 26, 2017
by
Soumith Chintala
Committed by
GitHub
Sep 26, 2017
Browse files
Merge pull request #240 from chsasank/master
Refactor of transforms
parents
addfbd1d
2cc58ed0
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
332 additions
and
105 deletions
+332
-105
torchvision/transforms.py
torchvision/transforms.py
+332
-105
No files found.
torchvision/transforms.py
View file @
459dc59e
...
...
@@ -13,43 +13,35 @@ import types
import
collections
class
Compose
(
object
):
"""Composes several transforms together.
def _is_pil_image(img):
    """Return True if *img* is a PIL Image (or an accimage Image when the
    optional accimage backend is available)."""
    if accimage is None:
        return isinstance(img, Image.Image)
    return isinstance(img, (Image.Image, accimage.Image))
Args:
transforms (list of ``Transform`` objects): list of transforms to compose.
Example:
>>> transforms.Compose([
>>> transforms.CenterCrop(10),
>>> transforms.ToTensor(),
>>> ])
"""
def
_is_tensor_image
(
img
):
return
torch
.
is_tensor
(
img
)
and
img
.
ndimension
()
==
3
def
__init__
(
self
,
transforms
):
self
.
transforms
=
transforms
def
__call__
(
self
,
img
):
for
t
in
self
.
transforms
:
img
=
t
(
img
)
return
img
def
_is_numpy_image
(
img
):
return
isinstance
(
img
,
np
.
ndarray
)
and
(
img
.
ndim
in
{
2
,
3
})
class
ToTensor
(
object
):
def
to_tensor
(
pic
):
"""Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.
Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
"""
See ``ToTensor`` for more details.
def
__call__
(
self
,
pic
):
"""
Args:
pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.
Returns:
Tensor: Converted image.
"""
if
not
(
_is_pil_image
(
pic
)
or
_is_numpy_image
(
pic
)):
raise
TypeError
(
'pic should be PIL Image or ndarray. Got {}'
.
format
(
type
(
pic
)))
if
isinstance
(
pic
,
np
.
ndarray
):
# handle numpy array
img
=
torch
.
from_numpy
(
pic
.
transpose
((
2
,
0
,
1
)))
...
...
@@ -85,29 +77,27 @@ class ToTensor(object):
return
img
class
ToPILImage
(
object
):
"""Convert a tensor to PIL Image.
def
to_pil_image
(
pic
):
"""Convert a tensor
or an ndarray
to PIL Image.
Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
H x W x C to a PIL.Image while preserving the value range.
"""
See ``ToPILImage`` for more details.
def
__call__
(
self
,
pic
):
"""
Args:
pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.
Returns:
PIL.Image: Image converted to PIL.Image.
"""
if
not
(
_is_numpy_image
(
pic
)
or
_is_tensor_image
(
pic
)):
raise
TypeError
(
'pic should be Tensor or ndarray. Got {}.'
.
format
(
type
(
pic
)))
npimg
=
pic
mode
=
None
if
isinstance
(
pic
,
torch
.
FloatTensor
):
pic
=
pic
.
mul
(
255
).
byte
()
if
torch
.
is_tensor
(
pic
):
npimg
=
np
.
transpose
(
pic
.
numpy
(),
(
1
,
2
,
0
))
assert
isinstance
(
npimg
,
np
.
ndarray
)
,
'pic should be Tensor or ndarray'
assert
isinstance
(
npimg
,
np
.
ndarray
)
if
npimg
.
shape
[
2
]
==
1
:
npimg
=
npimg
[:,
:,
0
]
...
...
@@ -129,6 +119,212 @@ class ToPILImage(object):
return
Image
.
fromarray
(
npimg
,
mode
=
mode
)
def normalize(tensor, mean, std):
    """Normalize a tensor image with mean and standard deviation.

    The tensor is modified in place (channel-wise ``sub_``/``div_``) and
    also returned.

    See ``Normalize`` for more details.

    Args:
        tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        mean (sequence): Sequence of means for R, G, B channels respectively.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respectively.

    Returns:
        Tensor: Normalized image.

    Raises:
        TypeError: If ``tensor`` is not a 3-D torch tensor.
    """
    if not _is_tensor_image(tensor):
        raise TypeError('tensor is not a torch image.')
    # TODO: make efficient
    # Iterating over a tensor yields its channel slices; each channel is
    # shifted and scaled in place.
    for t, m, s in zip(tensor, mean, std):
        t.sub_(m).div_(s)
    return tensor
def scale(img, size, interpolation=Image.BILINEAR):
    """Rescale the input PIL.Image to the given size.

    Args:
        img (PIL.Image): Image to be scaled.
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``

    Returns:
        PIL.Image: Rescaled image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    is_pair = isinstance(size, collections.Iterable) and len(size) == 2
    if not (isinstance(size, int) or is_pair):
        raise TypeError('Got inappropriate size arg: {}'.format(size))

    if not isinstance(size, int):
        # size is (h, w); PIL's resize expects (w, h), hence the reversal.
        return img.resize(size[::-1], interpolation)

    w, h = img.size
    if w < h:
        # Width is the smaller edge.
        if w == size:
            return img
        ow, oh = size, int(size * h / w)
    else:
        # Height is the smaller edge (or the image is square).
        if h == size:
            return img
        oh, ow = size, int(size * w / h)
    return img.resize((ow, oh), interpolation)
def pad(img, padding, fill=0):
    """Pad the given PIL.Image on all sides with the given "pad" value.

    Args:
        img (PIL.Image): Image to be padded.
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.

    Returns:
        PIL.Image: Padded image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    if not isinstance(padding, (numbers.Number, tuple)):
        raise TypeError('Got inappropriate padding arg')
    if not isinstance(fill, (numbers.Number, str, tuple)):
        raise TypeError('Got inappropriate fill arg')

    # A tuple padding must have exactly 2 or 4 elements.
    if isinstance(padding, collections.Sequence) and len(padding) not in (2, 4):
        raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                         "{} element tuple".format(len(padding)))

    return ImageOps.expand(img, border=padding, fill=fill)
def crop(img, i, j, h, w):
    """Crop the given PIL.Image.

    Args:
        img (PIL.Image): Image to be cropped.
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.

    Returns:
        PIL.Image: Cropped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    # PIL's crop box is (left, upper, right, lower).
    box = (j, i, j + w, i + h)
    return img.crop(box)
def scaled_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
    """Crop the given PIL.Image and scale it to the desired size.

    Notably used in RandomSizedCrop.

    Args:
        img (PIL.Image): Image to be cropped.
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.
        size (sequence or int): Desired output size. Same semantics as ``scale``.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``.

    Returns:
        PIL.Image: Cropped and scaled image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    # Raise instead of assert: asserts are stripped under ``python -O``, and
    # the sibling helpers (``crop``, ``scale``) raise TypeError for this case.
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    img = crop(img, i, j, h, w)
    img = scale(img, size, interpolation)
    return img
def hflip(img):
    """Horizontally flip the given PIL.Image.

    Args:
        img (PIL.Image): Image to be flipped.

    Returns:
        PIL.Image: Horizontally flipped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    return img.transpose(Image.FLIP_LEFT_RIGHT)
class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        # Thread the image through each transform in order.
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result
class ToTensor(object):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        # Thin object wrapper around the functional API.
        return to_tensor(pic)
class ToPILImage(object):
    """Convert a tensor or an ndarray to PIL Image.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL.Image while preserving the value range.
    """

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.

        Returns:
            PIL.Image: Image converted to PIL.Image.
        """
        # Thin object wrapper around the functional API.
        return to_pil_image(pic)
class
Normalize
(
object
):
"""Normalize an tensor image with mean and standard deviation.
...
...
@@ -154,10 +350,7 @@ class Normalize(object):
Returns:
Tensor: Normalized image.
"""
# TODO: make efficient
for
t
,
m
,
s
in
zip
(
tensor
,
self
.
mean
,
self
.
std
):
t
.
sub_
(
m
).
div_
(
s
)
return
tensor
return
normalize
(
tensor
,
self
.
mean
,
self
.
std
)
class
Scale
(
object
):
...
...
@@ -186,20 +379,7 @@ class Scale(object):
Returns:
PIL.Image: Rescaled image.
"""
if
isinstance
(
self
.
size
,
int
):
w
,
h
=
img
.
size
if
(
w
<=
h
and
w
==
self
.
size
)
or
(
h
<=
w
and
h
==
self
.
size
):
return
img
if
w
<
h
:
ow
=
self
.
size
oh
=
int
(
self
.
size
*
h
/
w
)
return
img
.
resize
((
ow
,
oh
),
self
.
interpolation
)
else
:
oh
=
self
.
size
ow
=
int
(
self
.
size
*
w
/
h
)
return
img
.
resize
((
ow
,
oh
),
self
.
interpolation
)
else
:
return
img
.
resize
(
self
.
size
[::
-
1
],
self
.
interpolation
)
return
scale
(
img
,
self
.
size
,
self
.
interpolation
)
class
CenterCrop
(
object
):
...
...
@@ -217,6 +397,23 @@ class CenterCrop(object):
else
:
self
.
size
=
size
@
staticmethod
def
get_params
(
img
,
output_size
):
"""Get parameters for ``crop`` for center crop.
Args:
img (PIL.Image): Image to be cropped.
output_size (tuple): Expected output size of the crop.
Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for center crop.
"""
w
,
h
=
img
.
size
th
,
tw
=
output_size
i
=
int
(
round
((
h
-
th
)
/
2.
))
j
=
int
(
round
((
w
-
tw
)
/
2.
))
return
i
,
j
,
th
,
tw
def
__call__
(
self
,
img
):
"""
Args:
...
...
@@ -225,11 +422,8 @@ class CenterCrop(object):
Returns:
PIL.Image: Cropped image.
"""
w
,
h
=
img
.
size
th
,
tw
=
self
.
size
x1
=
int
(
round
((
w
-
tw
)
/
2.
))
y1
=
int
(
round
((
h
-
th
)
/
2.
))
return
img
.
crop
((
x1
,
y1
,
x1
+
tw
,
y1
+
th
))
i
,
j
,
h
,
w
=
self
.
get_params
(
img
,
self
.
size
)
return
crop
(
img
,
i
,
j
,
h
,
w
)
class
Pad
(
object
):
...
...
@@ -263,7 +457,7 @@ class Pad(object):
Returns:
PIL.Image: Padded image.
"""
return
ImageOps
.
ex
pa
n
d
(
img
,
border
=
self
.
padding
,
fill
=
self
.
fill
)
return
pad
(
img
,
self
.
padding
,
self
.
fill
)
class
Lambda
(
object
):
...
...
@@ -301,6 +495,26 @@ class RandomCrop(object):
self
.
size
=
size
self
.
padding
=
padding
@
staticmethod
def
get_params
(
img
,
output_size
):
"""Get parameters for ``crop`` for a random crop.
Args:
img (PIL.Image): Image to be cropped.
output_size (tuple): Expected output size of the crop.
Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
"""
w
,
h
=
img
.
size
th
,
tw
=
output_size
if
w
==
tw
and
h
==
th
:
return
img
i
=
random
.
randint
(
0
,
h
-
th
)
j
=
random
.
randint
(
0
,
w
-
tw
)
return
i
,
j
,
th
,
tw
def
__call__
(
self
,
img
):
"""
Args:
...
...
@@ -310,16 +524,11 @@ class RandomCrop(object):
PIL.Image: Cropped image.
"""
if
self
.
padding
>
0
:
img
=
ImageOps
.
ex
pa
n
d
(
img
,
border
=
self
.
padding
,
fill
=
0
)
img
=
pad
(
img
,
self
.
padding
)
w
,
h
=
img
.
size
th
,
tw
=
self
.
size
if
w
==
tw
and
h
==
th
:
return
img
i
,
j
,
h
,
w
=
self
.
get_params
(
img
,
self
.
size
)
x1
=
random
.
randint
(
0
,
w
-
tw
)
y1
=
random
.
randint
(
0
,
h
-
th
)
return
img
.
crop
((
x1
,
y1
,
x1
+
tw
,
y1
+
th
))
return
crop
(
img
,
i
,
j
,
h
,
w
)
class
RandomHorizontalFlip
(
object
):
...
...
@@ -334,7 +543,7 @@ class RandomHorizontalFlip(object):
PIL.Image: Randomly flipped image.
"""
if
random
.
random
()
<
0.5
:
return
img
.
transpose
(
Image
.
FLIP_LEFT_RIGHT
)
return
hflip
(
img
)
return
img
...
...
@@ -347,15 +556,25 @@ class RandomSizedCrop(object):
This is popularly used to train the Inception networks.
Args:
size:
size of the smaller
edge
size:
expected output size of each
edge
interpolation: Default: PIL.Image.BILINEAR
"""
def
__init__
(
self
,
size
,
interpolation
=
Image
.
BILINEAR
):
self
.
size
=
size
self
.
size
=
(
size
,
size
)
self
.
interpolation
=
interpolation
def
__call__
(
self
,
img
):
@
staticmethod
def
get_params
(
img
):
"""Get parameters for ``crop`` for a random sized crop.
Args:
img (PIL.Image): Image to be cropped.
Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
sized crop.
"""
for
attempt
in
range
(
10
):
area
=
img
.
size
[
0
]
*
img
.
size
[
1
]
target_area
=
random
.
uniform
(
0.08
,
1.0
)
*
area
...
...
@@ -368,15 +587,23 @@ class RandomSizedCrop(object):
w
,
h
=
h
,
w
if
w
<=
img
.
size
[
0
]
and
h
<=
img
.
size
[
1
]:
x1
=
random
.
randint
(
0
,
img
.
size
[
0
]
-
w
)
y1
=
random
.
randint
(
0
,
img
.
size
[
1
]
-
h
)
i
=
random
.
randint
(
0
,
img
.
size
[
1
]
-
h
)
j
=
random
.
randint
(
0
,
img
.
size
[
0
]
-
w
)
return
i
,
j
,
h
,
w
img
=
img
.
crop
((
x1
,
y1
,
x1
+
w
,
y1
+
h
))
assert
(
img
.
size
==
(
w
,
h
))
# Fallback
w
=
min
(
img
.
size
[
0
],
img
.
shape
[
1
])
i
=
(
img
.
shape
[
1
]
-
w
)
//
2
j
=
(
img
.
shape
[
0
]
-
w
)
//
2
return
i
,
j
,
w
,
w
return
img
.
resize
((
self
.
size
,
self
.
size
),
self
.
interpolation
)
def
__call__
(
self
,
img
):
"""
Args:
img (PIL.Image): Image to be flipped.
# Fallback
scale
=
Scale
(
self
.
size
,
interpolation
=
self
.
interpolation
)
crop
=
CenterCrop
(
self
.
size
)
return
crop
(
scale
(
img
))
Returns:
PIL.Image: Randomly cropped and scaled image.
"""
i
,
j
,
h
,
w
=
self
.
get_params
(
img
)
return
scaled_crop
(
img
,
i
,
j
,
h
,
w
,
self
.
size
,
self
.
interpolation
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment