Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
4390b559
Commit
4390b559
authored
Sep 16, 2017
by
Sasank Chilamkurthy
Browse files
Make get_params static method
parent
8b18f526
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
9 deletions
+12
-9
torchvision/transforms.py
torchvision/transforms.py
+12
-9
No files found.
torchvision/transforms.py
View file @
4390b559
...
@@ -67,7 +67,7 @@ def to_tensor(pic):
...
@@ -67,7 +67,7 @@ def to_tensor(pic):
return
img
return
img
def
to_pilimage
(
pic
):
def
to_pil
_
image
(
pic
):
if
not
(
_is_numpy_image
(
pic
)
or
_is_tensor_image
(
pic
)):
if
not
(
_is_numpy_image
(
pic
)
or
_is_tensor_image
(
pic
)):
raise
TypeError
(
'pic should be Tensor or ndarray. Got {}.'
.
format
(
type
(
pic
)))
raise
TypeError
(
'pic should be Tensor or ndarray. Got {}.'
.
format
(
type
(
pic
)))
...
@@ -219,7 +219,7 @@ class ToPILImage(object):
...
@@ -219,7 +219,7 @@ class ToPILImage(object):
PIL.Image: Image converted to PIL.Image.
PIL.Image: Image converted to PIL.Image.
"""
"""
return
to_pilimage
(
pic
)
return
to_pil
_
image
(
pic
)
class
Normalize
(
object
):
class
Normalize
(
object
):
...
@@ -294,9 +294,10 @@ class CenterCrop(object):
...
@@ -294,9 +294,10 @@ class CenterCrop(object):
else
:
else
:
self
.
size
=
size
self
.
size
=
size
def
get_params
(
self
,
img
):
@
staticmethod
def
get_params
(
img
,
output_size
):
w
,
h
=
img
.
size
w
,
h
=
img
.
size
th
,
tw
=
self
.
size
th
,
tw
=
output_
size
x1
=
int
(
round
((
w
-
tw
)
/
2.
))
x1
=
int
(
round
((
w
-
tw
)
/
2.
))
y1
=
int
(
round
((
h
-
th
)
/
2.
))
y1
=
int
(
round
((
h
-
th
)
/
2.
))
return
x1
,
y1
,
tw
,
th
return
x1
,
y1
,
tw
,
th
...
@@ -309,7 +310,7 @@ class CenterCrop(object):
...
@@ -309,7 +310,7 @@ class CenterCrop(object):
Returns:
Returns:
PIL.Image: Cropped image.
PIL.Image: Cropped image.
"""
"""
x1
,
y1
,
tw
,
th
=
self
.
get_params
(
img
)
x1
,
y1
,
tw
,
th
=
self
.
get_params
(
img
,
self
.
size
)
return
crop
(
img
,
x1
,
y1
,
tw
,
th
)
return
crop
(
img
,
x1
,
y1
,
tw
,
th
)
...
@@ -382,9 +383,10 @@ class RandomCrop(object):
...
@@ -382,9 +383,10 @@ class RandomCrop(object):
self
.
size
=
size
self
.
size
=
size
self
.
padding
=
padding
self
.
padding
=
padding
def
get_params
(
self
,
img
):
@
staticmethod
def
get_params
(
img
,
output_size
):
w
,
h
=
img
.
size
w
,
h
=
img
.
size
th
,
tw
=
self
.
size
th
,
tw
=
output_
size
if
w
==
tw
and
h
==
th
:
if
w
==
tw
and
h
==
th
:
return
img
return
img
...
@@ -403,7 +405,7 @@ class RandomCrop(object):
...
@@ -403,7 +405,7 @@ class RandomCrop(object):
if
self
.
padding
>
0
:
if
self
.
padding
>
0
:
img
=
pad
(
img
,
self
.
padding
)
img
=
pad
(
img
,
self
.
padding
)
x1
,
y1
,
tw
,
th
=
self
.
get_params
(
img
)
x1
,
y1
,
tw
,
th
=
self
.
get_params
(
img
,
self
.
size
)
return
crop
(
img
,
x1
,
y1
,
tw
,
th
)
return
crop
(
img
,
x1
,
y1
,
tw
,
th
)
...
@@ -441,7 +443,8 @@ class RandomSizedCrop(object):
...
@@ -441,7 +443,8 @@ class RandomSizedCrop(object):
self
.
size
=
size
self
.
size
=
size
self
.
interpolation
=
interpolation
self
.
interpolation
=
interpolation
def
get_params
(
self
,
img
):
@
staticmethod
def
get_params
(
img
):
for
attempt
in
range
(
10
):
for
attempt
in
range
(
10
):
area
=
img
.
size
[
0
]
*
img
.
size
[
1
]
area
=
img
.
size
[
0
]
*
img
.
size
[
1
]
target_area
=
random
.
uniform
(
0.08
,
1.0
)
*
area
target_area
=
random
.
uniform
(
0.08
,
1.0
)
*
area
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment