Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
6f342d3a
Commit
6f342d3a
authored
Jan 19, 2017
by
Soumith Chintala
Committed by
GitHub
Jan 19, 2017
Browse files
Merge pull request #33 from pytorch/improvements
Minor improvements
parents
9896626a
72cd478e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
15 deletions
+20
-15
torchvision/datasets/mnist.py
torchvision/datasets/mnist.py
+2
-3
torchvision/transforms.py
torchvision/transforms.py
+18
-12
No files found.
torchvision/datasets/mnist.py
View file @
6f342d3a
...
@@ -72,7 +72,6 @@ class MNIST(data.Dataset):
...
@@ -72,7 +72,6 @@ class MNIST(data.Dataset):
import
gzip
import
gzip
if
self
.
_check_exists
():
if
self
.
_check_exists
():
print
(
'Files already downloaded'
)
return
return
# download files
# download files
...
@@ -98,8 +97,8 @@ class MNIST(data.Dataset):
...
@@ -98,8 +97,8 @@ class MNIST(data.Dataset):
os
.
unlink
(
file_path
)
os
.
unlink
(
file_path
)
# process and save as torch files
# process and save as torch files
print
(
'Processing'
)
print
(
'Processing
...
'
)
training_set
=
(
training_set
=
(
read_image_file
(
os
.
path
.
join
(
self
.
root
,
self
.
raw_folder
,
'train-images-idx3-ubyte'
)),
read_image_file
(
os
.
path
.
join
(
self
.
root
,
self
.
raw_folder
,
'train-images-idx3-ubyte'
)),
read_label_file
(
os
.
path
.
join
(
self
.
root
,
self
.
raw_folder
,
'train-labels-idx1-ubyte'
))
read_label_file
(
os
.
path
.
join
(
self
.
root
,
self
.
raw_folder
,
'train-labels-idx1-ubyte'
))
...
...
torchvision/transforms.py
View file @
6f342d3a
...
@@ -8,12 +8,16 @@ import numbers
...
@@ -8,12 +8,16 @@ import numbers
import
types
import
types
class
Compose
(
object
):
class
Compose
(
object
):
""" Composes several transforms together.
"""Composes several transforms together.
For example:
>>> transforms.Compose([
Args:
>>> transforms.CenterCrop(10),
transforms (List[Transform]): list of transforms to compose.
>>> transforms.ToTensor(),
>>> ])
Example:
>>> transforms.Compose([
>>> transforms.CenterCrop(10),
>>> transforms.ToTensor(),
>>> ])
"""
"""
def
__init__
(
self
,
transforms
):
def
__init__
(
self
,
transforms
):
self
.
transforms
=
transforms
self
.
transforms
=
transforms
...
@@ -25,8 +29,9 @@ class Compose(object):
...
@@ -25,8 +29,9 @@ class Compose(object):
class
ToTensor
(
object
):
class
ToTensor
(
object
):
""" Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
"""Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range
to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
"""
def
__call__
(
self
,
pic
):
def
__call__
(
self
,
pic
):
if
isinstance
(
pic
,
np
.
ndarray
):
if
isinstance
(
pic
,
np
.
ndarray
):
# handle numpy array
# handle numpy array
...
@@ -40,8 +45,9 @@ class ToTensor(object):
...
@@ -40,8 +45,9 @@ class ToTensor(object):
img
=
img
.
transpose
(
0
,
1
).
transpose
(
0
,
2
).
contiguous
()
img
=
img
.
transpose
(
0
,
1
).
transpose
(
0
,
2
).
contiguous
()
return
img
.
float
().
div
(
255
)
return
img
.
float
().
div
(
255
)
class
ToPILImage
(
object
):
class
ToPILImage
(
object
):
"""
Converts a torch.*Tensor of range [0, 1] and shape C x H x W
"""Converts a torch.*Tensor of range [0, 1] and shape C x H x W
or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C
or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C
to a PIL.Image of range [0, 255]
to a PIL.Image of range [0, 255]
"""
"""
...
@@ -56,7 +62,7 @@ class ToPILImage(object):
...
@@ -56,7 +62,7 @@ class ToPILImage(object):
return
img
return
img
class
Normalize
(
object
):
class
Normalize
(
object
):
"""
Given mean: (R, G, B) and std: (R, G, B),
"""Given mean: (R, G, B) and std: (R, G, B),
will normalize each channel of the torch.*Tensor, i.e.
will normalize each channel of the torch.*Tensor, i.e.
channel = (channel - mean) / std
channel = (channel - mean) / std
"""
"""
...
@@ -72,7 +78,7 @@ class Normalize(object):
...
@@ -72,7 +78,7 @@ class Normalize(object):
class
Scale
(
object
):
class
Scale
(
object
):
"""
Rescales the input PIL.Image to the given 'size'.
"""Rescales the input PIL.Image to the given 'size'.
'size' will be the size of the smaller edge.
'size' will be the size of the smaller edge.
For example, if height > width, then image will be
For example, if height > width, then image will be
rescaled to (size * height / width, size)
rescaled to (size * height / width, size)
...
@@ -128,7 +134,7 @@ class Pad(object):
...
@@ -128,7 +134,7 @@ class Pad(object):
return
ImageOps
.
expand
(
img
,
border
=
self
.
padding
,
fill
=
self
.
fill
)
return
ImageOps
.
expand
(
img
,
border
=
self
.
padding
,
fill
=
self
.
fill
)
class
Lambda
(
object
):
class
Lambda
(
object
):
"""Applies a lambda as a transform"""
"""Applies a lambda as a transform
.
"""
def
__init__
(
self
,
lambd
):
def
__init__
(
self
,
lambd
):
assert
type
(
lambd
)
is
types
.
LambdaType
assert
type
(
lambd
)
is
types
.
LambdaType
self
.
lambd
=
lambd
self
.
lambd
=
lambd
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment