Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Pytorch-Encoding
Commits
679f02c2
Commit
679f02c2
authored
Oct 15, 2017
by
Hang Zhang
Browse files
unuse
parent
e6386d0b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
0 additions
and
126 deletions
+0
-126
experiments/dataset/dataloader.py
experiments/dataset/dataloader.py
+0
-126
No files found.
experiments/dataset/dataloader.py
deleted
100644 → 0
View file @
e6386d0b
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# refer to https://github.com/pytorch/vision/blob/master/torchvision/
import
torch.utils.data
as
data
import
torchvision
from
PIL
import
Image
import
os
import
os.path
IMG_EXTENSIONS
=
[
'.jpg'
,
'.JPG'
,
'.jpeg'
,
'.JPEG'
,
'.png'
,
'.PNG'
,
'.ppm'
,
'.PPM'
,
'.bmp'
,
'.BMP'
,
]
def
is_image_file
(
filename
):
return
any
(
filename
.
endswith
(
extension
)
for
extension
in
IMG_EXTENSIONS
)
def
find_classes
(
dir
):
classes
=
[
d
for
d
in
os
.
listdir
(
dir
)
if
os
.
path
.
isdir
(
os
.
path
.
join
(
dir
,
d
))]
classes
.
sort
()
class_to_idx
=
{
classes
[
i
]:
i
for
i
in
range
(
len
(
classes
))}
return
classes
,
class_to_idx
def
make_dataset
(
dir
,
class_to_idx
):
images
=
[]
for
target
in
os
.
listdir
(
dir
):
d
=
os
.
path
.
join
(
dir
,
target
,
'images'
)
if
not
os
.
path
.
isdir
(
d
):
continue
for
root
,
_
,
fnames
in
sorted
(
os
.
walk
(
d
)):
for
fname
in
fnames
:
if
is_image_file
(
fname
):
path
=
os
.
path
.
join
(
root
,
fname
)
item
=
(
path
,
class_to_idx
[
target
])
images
.
append
(
item
)
return
images
def
default_loader
(
path
):
return
Image
.
open
(
path
).
convert
(
'RGB'
)
class
DatasetLoader
(
data
.
Dataset
):
def
__init__
(
self
,
root
,
transform
=
None
,
target_transform
=
None
,
loader
=
default_loader
):
classes
,
class_to_idx
=
find_classes
(
root
)
imgs
=
make_dataset
(
root
,
class_to_idx
)
if
len
(
imgs
)
==
0
:
raise
(
RuntimeError
(
"Found 0 images in subfolders of: "
+
root
\
+
"
\n
Supported image extensions are: "
+
\
","
.
join
(
IMG_EXTENSIONS
)))
self
.
root
=
root
self
.
imgs
=
imgs
self
.
classes
=
classes
self
.
class_to_idx
=
class_to_idx
self
.
transform
=
transform
self
.
target_transform
=
target_transform
self
.
loader
=
loader
def
__getitem__
(
self
,
index
):
path
,
target
=
self
.
imgs
[
index
]
img
=
self
.
loader
(
path
)
if
self
.
transform
is
not
None
:
img
=
self
.
transform
(
img
)
if
self
.
target_transform
is
not
None
:
target
=
self
.
target_transform
(
target
)
return
img
,
target
def
__len__
(
self
):
return
len
(
self
.
imgs
)
def
annotation_reader
(
root
,
class_to_idx
):
# read the tiny imagenet annotations.txt and returns the imgs and class
file
=
open
(
os
.
path
.
join
(
root
,
'val_annotations.txt'
),
'r'
)
images
=
[]
for
line
in
file
:
sp
=
line
.
split
(
'
\t
'
)
path
=
os
.
path
.
join
(
root
,
'images'
,
sp
[
0
])
item
=
[
path
,
class_to_idx
[
sp
[
1
]]]
images
.
append
(
item
)
return
images
class
ValDatasetLoader
(
data
.
Dataset
):
def
__init__
(
self
,
root
,
classes
,
class_to_idx
,
transform
=
None
,
target_transform
=
None
,
loader
=
default_loader
):
imgs
=
annotation_reader
(
root
,
class_to_idx
)
self
.
root
=
root
self
.
imgs
=
imgs
self
.
classes
=
classes
self
.
class_to_idx
=
class_to_idx
self
.
transform
=
transform
self
.
target_transform
=
target_transform
self
.
loader
=
loader
def
__getitem__
(
self
,
index
):
path
,
target
=
self
.
imgs
[
index
]
img
=
self
.
loader
(
path
)
if
self
.
transform
is
not
None
:
img
=
self
.
transform
(
img
)
if
self
.
target_transform
is
not
None
:
target
=
self
.
target_transform
(
target
)
return
img
,
target
def
__len__
(
self
):
return
len
(
self
.
imgs
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment