OpenDAS / vision · Commits · 5a2bbc57

Commit 5a2bbc57 authored Sep 03, 2017 by Sasank Chilamkurthy

First cut refactoring

(cherry picked from commit 71afec427baca8e37cd9e10d98812bc586e9a4ac)
parent 8e375670

Showing 1 changed file with 146 additions and 100 deletions.

torchvision/transforms.py  +146  -100
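The commit moves the body of each transform's __call__ into a module-level function (to_tensor, to_pilimage, normalize, scale, pad, crop, scaled_crop, hflip) and turns the classes into thin wrappers around them. A minimal sketch of what this means for callers, assuming this revision of torchvision is importable; the image path is a placeholder, not part of the diff:

    from PIL import Image
    import torch
    import torchvision.transforms as transforms

    img = Image.open('example.jpg').convert('RGB')   # placeholder input

    # Class interface, unchanged for existing callers
    t1 = transforms.ToTensor()(img)

    # New functional interface added by this commit
    t2 = transforms.to_tensor(img)

    assert torch.equal(t1, t2)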
@@ -13,6 +13,112 @@ import types
 import collections


+def to_tensor(pic):
+    if isinstance(pic, np.ndarray):
+        # handle numpy array
+        img = torch.from_numpy(pic.transpose((2, 0, 1)))
+        # backward compatibility
+        return img.float().div(255)
+
+    if accimage is not None and isinstance(pic, accimage.Image):
+        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
+        pic.copyto(nppic)
+        return torch.from_numpy(nppic)
+
+    # handle PIL Image
+    if pic.mode == 'I':
+        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
+    elif pic.mode == 'I;16':
+        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
+    else:
+        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
+    # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
+    if pic.mode == 'YCbCr':
+        nchannel = 3
+    elif pic.mode == 'I;16':
+        nchannel = 1
+    else:
+        nchannel = len(pic.mode)
+    img = img.view(pic.size[1], pic.size[0], nchannel)
+    # put it from HWC to CHW format
+    # yikes, this transpose takes 80% of the loading time/CPU
+    img = img.transpose(0, 1).transpose(0, 2).contiguous()
+    if isinstance(img, torch.ByteTensor):
+        return img.float().div(255)
+    else:
+        return img
+
+
+def to_pilimage(pic):
+    npimg = pic
+    mode = None
+    if isinstance(pic, torch.FloatTensor):
+        pic = pic.mul(255).byte()
+    if torch.is_tensor(pic):
+        npimg = np.transpose(pic.numpy(), (1, 2, 0))
+    assert isinstance(npimg, np.ndarray), 'pic should be Tensor or ndarray'
+    if npimg.shape[2] == 1:
+        npimg = npimg[:, :, 0]
+
+        if npimg.dtype == np.uint8:
+            mode = 'L'
+        if npimg.dtype == np.int16:
+            mode = 'I;16'
+        if npimg.dtype == np.int32:
+            mode = 'I'
+        elif npimg.dtype == np.float32:
+            mode = 'F'
+    else:
+        if npimg.dtype == np.uint8:
+            mode = 'RGB'
+    assert mode is not None, '{} is not supported'.format(npimg.dtype)
+    return Image.fromarray(npimg, mode=mode)
+
+
+def normalize(tensor, mean, std):
+    # TODO: make efficient
+    for t, m, s in zip(tensor, mean, std):
+        t.sub_(m).div_(s)
+    return tensor
+
+
+def scale(img, size, interpolation=Image.BILINEAR):
+    assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
+    if isinstance(size, int):
+        w, h = img.size
+        if (w <= h and w == size) or (h <= w and h == size):
+            return img
+        if w < h:
+            ow = size
+            oh = int(size * h / w)
+            return img.resize((ow, oh), interpolation)
+        else:
+            oh = size
+            ow = int(size * w / h)
+            return img.resize((ow, oh), interpolation)
+    else:
+        return img.resize(size, interpolation)
+
+
+def pad(img, padding, fill=0):
+    assert isinstance(padding, numbers.Number)
+    assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple)
+    return ImageOps.expand(img, border=padding, fill=fill)
+
+
+def crop(img, x, y, w, h):
+    return img.crop((x, y, x + w, y + h))
+
+
+def scaled_crop(img, x, y, w, h, size, interpolation=Image.BILINEAR):
+    img = crop(img, x, y, w, h)
+    img = scale(img, size, interpolation)
+    return img
+
+
+def hflip(img):
+    return img.transpose(Image.FLIP_LEFT_RIGHT)
+
+
 class Compose(object):
     """Composes several transforms together.
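Because the functions above live at module level, a preprocessing pipeline can now be written without instantiating any transform class. A rough sketch, assuming a PIL RGB input; the path, crop offsets, and normalization statistics are illustrative values, not part of the diff:

    from PIL import Image
    import torchvision.transforms as transforms

    img = Image.open('example.jpg')                    # placeholder input
    img = transforms.scale(img, 256)                   # shorter side -> 256, aspect ratio kept
    img = transforms.crop(img, 16, 16, 224, 224)       # x, y, w, h
    img = transforms.hflip(img)                        # deterministic horizontal flip
    tensor = transforms.to_tensor(img)                 # CHW FloatTensor in [0, 1]
    tensor = transforms.normalize(tensor,
                                  mean=[0.485, 0.456, 0.406],   # ImageNet stats, illustrative
                                  std=[0.229, 0.224, 0.225])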
@@ -50,39 +156,7 @@ class ToTensor(object):
         Returns:
             Tensor: Converted image.
         """
-        if isinstance(pic, np.ndarray):
-            # handle numpy array
-            img = torch.from_numpy(pic.transpose((2, 0, 1)))
-            # backward compatibility
-            return img.float().div(255)
-
-        if accimage is not None and isinstance(pic, accimage.Image):
-            nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
-            pic.copyto(nppic)
-            return torch.from_numpy(nppic)
-
-        # handle PIL Image
-        if pic.mode == 'I':
-            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
-        elif pic.mode == 'I;16':
-            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
-        else:
-            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
-        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
-        if pic.mode == 'YCbCr':
-            nchannel = 3
-        elif pic.mode == 'I;16':
-            nchannel = 1
-        else:
-            nchannel = len(pic.mode)
-        img = img.view(pic.size[1], pic.size[0], nchannel)
-        # put it from HWC to CHW format
-        # yikes, this transpose takes 80% of the loading time/CPU
-        img = img.transpose(0, 1).transpose(0, 2).contiguous()
-        if isinstance(img, torch.ByteTensor):
-            return img.float().div(255)
-        else:
-            return img
+        return to_tensor(pic)


 class ToPILImage(object):
@@ -101,29 +175,7 @@ class ToPILImage(object):
             PIL.Image: Image converted to PIL.Image.
         """
-        npimg = pic
-        mode = None
-        if isinstance(pic, torch.FloatTensor):
-            pic = pic.mul(255).byte()
-        if torch.is_tensor(pic):
-            npimg = np.transpose(pic.numpy(), (1, 2, 0))
-        assert isinstance(npimg, np.ndarray), 'pic should be Tensor or ndarray'
-        if npimg.shape[2] == 1:
-            npimg = npimg[:, :, 0]
-
-            if npimg.dtype == np.uint8:
-                mode = 'L'
-            if npimg.dtype == np.int16:
-                mode = 'I;16'
-            if npimg.dtype == np.int32:
-                mode = 'I'
-            elif npimg.dtype == np.float32:
-                mode = 'F'
-        else:
-            if npimg.dtype == np.uint8:
-                mode = 'RGB'
-        assert mode is not None, '{} is not supported'.format(npimg.dtype)
-        return Image.fromarray(npimg, mode=mode)
+        return to_pilimage(pic)


 class Normalize(object):
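to_pilimage is the inverse direction of to_tensor, so the two functional helpers allow a quick round trip when debugging a pipeline. A small sketch with a placeholder image; exact pixel values can differ slightly after the 255-scaling, but the geometry is preserved:

    from PIL import Image
    import torchvision.transforms as transforms

    pil_in = Image.open('example.jpg').convert('RGB')  # placeholder input
    t = transforms.to_tensor(pil_in)                   # 3xHxW FloatTensor in [0, 1]
    pil_out = transforms.to_pilimage(t)                # back to a uint8 'RGB' image
    assert pil_out.size == pil_in.size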
@@ -151,10 +203,7 @@ class Normalize(object):
         Returns:
             Tensor: Normalized image.
         """
-        # TODO: make efficient
-        for t, m, s in zip(tensor, self.mean, self.std):
-            t.sub_(m).div_(s)
-        return tensor
+        return normalize(tensor, self.mean, self.std)


 class Scale(object):
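Note that normalize works in place: each channel is modified with sub_/div_ and the same tensor is returned for chaining, so callers that still need the raw values should clone first. A small sketch with made-up statistics:

    import torch
    import torchvision.transforms as transforms

    t = torch.rand(3, 8, 8)                  # toy CHW tensor
    backup = t.clone()                       # normalize mutates t in place
    out = transforms.normalize(t, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    assert out is t                          # same object, returned for convenience
    assert not torch.equal(t, backup)        # values were shifted and scaled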
@@ -183,20 +232,7 @@ class Scale(object):
         Returns:
             PIL.Image: Rescaled image.
         """
-        if isinstance(self.size, int):
-            w, h = img.size
-            if (w <= h and w == self.size) or (h <= w and h == self.size):
-                return img
-            if w < h:
-                ow = self.size
-                oh = int(self.size * h / w)
-                return img.resize((ow, oh), self.interpolation)
-            else:
-                oh = self.size
-                ow = int(self.size * w / h)
-                return img.resize((ow, oh), self.interpolation)
-        else:
-            return img.resize(self.size, self.interpolation)
+        return scale(img, self.size, self.interpolation)


 class CenterCrop(object):
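As in the old Scale.__call__, the functional scale treats an int as "resize the shorter side to this length, keeping the aspect ratio", while a 2-tuple resizes to that exact (width, height). A quick sketch with an assumed 400x300 toy input:

    from PIL import Image
    import torchvision.transforms as transforms

    img = Image.new('RGB', (400, 300))                            # toy input, width x height
    assert transforms.scale(img, 150).size == (200, 150)          # shorter side -> 150
    assert transforms.scale(img, (150, 150)).size == (150, 150)   # exact (width, height)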
@@ -214,6 +250,13 @@ class CenterCrop(object):
         else:
             self.size = size

+    def get_params(self, img):
+        w, h = img.size
+        th, tw = self.size
+        x1 = int(round((w - tw) / 2.))
+        y1 = int(round((h - th) / 2.))
+        return x1, y1, tw, th
+
     def __call__(self, img):
         """
         Args:
@@ -222,11 +265,8 @@ class CenterCrop(object):
         Returns:
             PIL.Image: Cropped image.
         """
-        w, h = img.size
-        th, tw = self.size
-        x1 = int(round((w - tw) / 2.))
-        y1 = int(round((h - th) / 2.))
-        return img.crop((x1, y1, x1 + tw, y1 + th))
+        x1, y1, tw, th = self.get_params(img)
+        return crop(img, x1, y1, tw, th)


 class Pad(object):
@@ -260,7 +300,7 @@ class Pad(object):
         Returns:
             PIL.Image: Padded image.
         """
-        return ImageOps.expand(img, border=self.padding, fill=self.fill)
+        return pad(img, self.padding, self.fill)


 class Lambda(object):
@@ -298,6 +338,16 @@ class RandomCrop(object):
         self.size = size
         self.padding = padding

+    def get_params(self, img):
+        w, h = img.size
+        th, tw = self.size
+        if w == tw and h == th:
+            return 0, 0, tw, th
+
+        x1 = random.randint(0, w - tw)
+        y1 = random.randint(0, h - th)
+        return x1, y1, tw, th
+
     def __call__(self, img):
         """
         Args:
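Separating parameter sampling (get_params) from the functional crop makes it possible to apply one random crop consistently to paired inputs, for example an image and its segmentation mask. That pairing is not part of this diff; the sketch below only illustrates what the split enables, using placeholder inputs:

    from PIL import Image
    import torchvision.transforms as transforms

    img = Image.new('RGB', (256, 256))      # placeholder image
    mask = Image.new('L', (256, 256))       # placeholder segmentation mask

    rc = transforms.RandomCrop(224)
    x1, y1, tw, th = rc.get_params(img)     # sample the crop parameters once
    img_c = transforms.crop(img, x1, y1, tw, th)
    mask_c = transforms.crop(mask, x1, y1, tw, th)   # same region from the mask
    assert img_c.size == mask_c.size == (224, 224)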
@@ -307,16 +357,11 @@ class RandomCrop(object):
             PIL.Image: Cropped image.
         """
         if self.padding > 0:
-            img = ImageOps.expand(img, border=self.padding, fill=0)
+            img = pad(img, self.padding)

-        w, h = img.size
-        th, tw = self.size
-        if w == tw and h == th:
-            return img
-
-        x1 = random.randint(0, w - tw)
-        y1 = random.randint(0, h - th)
-        return img.crop((x1, y1, x1 + tw, y1 + th))
+        x1, y1, tw, th = self.get_params(img)
+        return crop(img, x1, y1, tw, th)


 class RandomHorizontalFlip(object):
@@ -331,7 +376,7 @@ class RandomHorizontalFlip(object):
             PIL.Image: Randomly flipped image.
         """
         if random.random() < 0.5:
-            return img.transpose(Image.FLIP_LEFT_RIGHT)
+            return hflip(img)
         return img
@@ -352,7 +397,7 @@ class RandomSizedCrop(object):
         self.size = size
         self.interpolation = interpolation

-    def __call__(self, img):
+    def get_params(self, img):
         for attempt in range(10):
             area = img.size[0] * img.size[1]
             target_area = random.uniform(0.08, 1.0) * area
@@ -365,15 +410,16 @@ class RandomSizedCrop(object):
                 w, h = h, w

             if w <= img.size[0] and h <= img.size[1]:
-                x1 = random.randint(0, img.size[0] - w)
-                y1 = random.randint(0, img.size[1] - h)
-
-                img = img.crop((x1, y1, x1 + w, y1 + h))
-                assert(img.size == (w, h))
-
-                return img.resize((self.size, self.size), self.interpolation)
+                x = random.randint(0, img.size[0] - w)
+                y = random.randint(0, img.size[1] - h)
+                return x, y, w, h

         # Fallback
-        scale = Scale(self.size, interpolation=self.interpolation)
-        crop = CenterCrop(self.size)
-        return crop(scale(img))
+        w = min(img.size[0], img.size[1])
+        x = (img.size[0] - w) // 2
+        y = (img.size[1] - w) // 2
+        return x, y, w, w
+
+    def __call__(self, img):
+        x, y, w, h = self.get_params(img)
+        return scaled_crop(img, x, y, w, h, self.size, self.interpolation)