Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
9563e3e3
Unverified
Commit
9563e3e3
authored
Mar 13, 2024
by
Nicolas Hug
Committed by
GitHub
Mar 13, 2024
Browse files
Add allow_empty parameter to ImageFolder and related utils (#8311)
parent
e00f4e66
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
43 additions
and
3 deletions
+43
-3
test/test_datasets.py
test/test_datasets.py
+21
-0
torchvision/datasets/folder.py
torchvision/datasets/folder.py
+22
-3
No files found.
test/test_datasets.py
View file @
9563e3e3
...
@@ -1620,6 +1620,10 @@ class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
...
@@ -1620,6 +1620,10 @@ class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
num_examples_total
+=
num_examples
num_examples_total
+=
num_examples
classes
.
append
(
cls
)
classes
.
append
(
cls
)
if
config
.
pop
(
"make_empty_class"
,
False
):
os
.
makedirs
(
pathlib
.
Path
(
tmpdir
)
/
"empty_class"
)
classes
.
append
(
"empty_class"
)
return
dict
(
num_examples
=
num_examples_total
,
classes
=
classes
)
return
dict
(
num_examples
=
num_examples_total
,
classes
=
classes
)
def
_file_name_fn
(
self
,
cls
,
ext
,
idx
):
def
_file_name_fn
(
self
,
cls
,
ext
,
idx
):
...
@@ -1644,6 +1648,23 @@ class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
...
@@ -1644,6 +1648,23 @@ class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
assert
len
(
dataset
.
classes
)
==
len
(
info
[
"classes"
])
assert
len
(
dataset
.
classes
)
==
len
(
info
[
"classes"
])
assert
all
([
a
==
b
for
a
,
b
in
zip
(
dataset
.
classes
,
info
[
"classes"
])])
assert
all
([
a
==
b
for
a
,
b
in
zip
(
dataset
.
classes
,
info
[
"classes"
])])
def test_allow_empty(self):
    """Check how the dataset treats a class folder that contains no files.

    With ``allow_empty=True`` the empty folder must be listed among the
    dataset classes; with ``allow_empty=False`` creating the dataset must
    fail with ``FileNotFoundError``.
    """
    config = {
        "extensions": self._EXTENSIONS,
        "make_empty_class": True,
        "allow_empty": True,
    }

    # Permissive mode: the empty folder is reported as a regular class.
    with self.create_dataset(config) as (dataset, info):
        expected_classes = info["classes"]
        assert "empty_class" in dataset.classes
        assert len(dataset.classes) == len(expected_classes)
        assert all(found == wanted for found, wanted in zip(dataset.classes, expected_classes))

    # Strict mode (the default): the empty folder is an error.
    config["allow_empty"] = False
    with pytest.raises(FileNotFoundError, match="Found no valid file"), self.create_dataset(config) as (
        dataset,
        info,
    ):
        pass
class
ImageFolderTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
class
ImageFolderTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
DATASET_CLASS
=
datasets
.
ImageFolder
DATASET_CLASS
=
datasets
.
ImageFolder
...
...
torchvision/datasets/folder.py
View file @
9563e3e3
...
@@ -50,6 +50,7 @@ def make_dataset(
...
@@ -50,6 +50,7 @@ def make_dataset(
class_to_idx
:
Optional
[
Dict
[
str
,
int
]]
=
None
,
class_to_idx
:
Optional
[
Dict
[
str
,
int
]]
=
None
,
extensions
:
Optional
[
Union
[
str
,
Tuple
[
str
,
...]]]
=
None
,
extensions
:
Optional
[
Union
[
str
,
Tuple
[
str
,
...]]]
=
None
,
is_valid_file
:
Optional
[
Callable
[[
str
],
bool
]]
=
None
,
is_valid_file
:
Optional
[
Callable
[[
str
],
bool
]]
=
None
,
allow_empty
:
bool
=
False
,
)
->
List
[
Tuple
[
str
,
int
]]:
)
->
List
[
Tuple
[
str
,
int
]]:
"""Generates a list of samples of a form (path_to_sample, class).
"""Generates a list of samples of a form (path_to_sample, class).
...
@@ -95,7 +96,7 @@ def make_dataset(
...
@@ -95,7 +96,7 @@ def make_dataset(
available_classes
.
add
(
target_class
)
available_classes
.
add
(
target_class
)
empty_classes
=
set
(
class_to_idx
.
keys
())
-
available_classes
empty_classes
=
set
(
class_to_idx
.
keys
())
-
available_classes
if
empty_classes
:
if
empty_classes
and
not
allow_empty
:
msg
=
f
"Found no valid file for the classes
{
', '
.
join
(
sorted
(
empty_classes
))
}
. "
msg
=
f
"Found no valid file for the classes
{
', '
.
join
(
sorted
(
empty_classes
))
}
. "
if
extensions
is
not
None
:
if
extensions
is
not
None
:
msg
+=
f
"Supported extensions are:
{
extensions
if
isinstance
(
extensions
,
str
)
else
', '
.
join
(
extensions
)
}
"
msg
+=
f
"Supported extensions are:
{
extensions
if
isinstance
(
extensions
,
str
)
else
', '
.
join
(
extensions
)
}
"
...
@@ -123,6 +124,8 @@ class DatasetFolder(VisionDataset):
...
@@ -123,6 +124,8 @@ class DatasetFolder(VisionDataset):
is_valid_file (callable, optional): A function that takes the path of a file
is_valid_file (callable, optional): A function that takes the path of a file
and checks if the file is a valid file (used to check for corrupt files).
and checks if the file is a valid file (used to check for corrupt files).
extensions and is_valid_file should not both be passed.
extensions and is_valid_file should not both be passed.
allow_empty (bool, optional): If True, empty folders are considered to be valid classes.
An error is raised on empty folders if False (default).
Attributes:
Attributes:
classes (list): List of the class names sorted alphabetically.
classes (list): List of the class names sorted alphabetically.
...
@@ -139,10 +142,17 @@ class DatasetFolder(VisionDataset):
...
@@ -139,10 +142,17 @@ class DatasetFolder(VisionDataset):
transform
:
Optional
[
Callable
]
=
None
,
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
is_valid_file
:
Optional
[
Callable
[[
str
],
bool
]]
=
None
,
is_valid_file
:
Optional
[
Callable
[[
str
],
bool
]]
=
None
,
allow_empty
:
bool
=
False
,
)
->
None
:
)
->
None
:
super
().
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
super
().
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
classes
,
class_to_idx
=
self
.
find_classes
(
self
.
root
)
classes
,
class_to_idx
=
self
.
find_classes
(
self
.
root
)
samples
=
self
.
make_dataset
(
self
.
root
,
class_to_idx
,
extensions
,
is_valid_file
)
samples
=
self
.
make_dataset
(
self
.
root
,
class_to_idx
=
class_to_idx
,
extensions
=
extensions
,
is_valid_file
=
is_valid_file
,
allow_empty
=
allow_empty
,
)
self
.
loader
=
loader
self
.
loader
=
loader
self
.
extensions
=
extensions
self
.
extensions
=
extensions
...
@@ -158,6 +168,7 @@ class DatasetFolder(VisionDataset):
...
@@ -158,6 +168,7 @@ class DatasetFolder(VisionDataset):
class_to_idx
:
Dict
[
str
,
int
],
class_to_idx
:
Dict
[
str
,
int
],
extensions
:
Optional
[
Tuple
[
str
,
...]]
=
None
,
extensions
:
Optional
[
Tuple
[
str
,
...]]
=
None
,
is_valid_file
:
Optional
[
Callable
[[
str
],
bool
]]
=
None
,
is_valid_file
:
Optional
[
Callable
[[
str
],
bool
]]
=
None
,
allow_empty
:
bool
=
False
,
)
->
List
[
Tuple
[
str
,
int
]]:
)
->
List
[
Tuple
[
str
,
int
]]:
"""Generates a list of samples of a form (path_to_sample, class).
"""Generates a list of samples of a form (path_to_sample, class).
...
@@ -172,6 +183,8 @@ class DatasetFolder(VisionDataset):
...
@@ -172,6 +183,8 @@ class DatasetFolder(VisionDataset):
and checks if the file is a valid file
and checks if the file is a valid file
(used to check for corrupt files). extensions and
(used to check for corrupt files). extensions and
is_valid_file should not both be passed. Defaults to None.
is_valid_file should not both be passed. Defaults to None.
allow_empty (bool, optional): If True, empty folders are considered to be valid classes.
An error is raised on empty folders if False (default).
Raises:
Raises:
ValueError: In case ``class_to_idx`` is empty.
ValueError: In case ``class_to_idx`` is empty.
...
@@ -186,7 +199,9 @@ class DatasetFolder(VisionDataset):
...
@@ -186,7 +199,9 @@ class DatasetFolder(VisionDataset):
# find_classes() function, instead of using that of the find_classes() method, which
# find_classes() function, instead of using that of the find_classes() method, which
# is potentially overridden and thus could have a different logic.
# is potentially overridden and thus could have a different logic.
raise
ValueError
(
"The class_to_idx parameter cannot be None."
)
raise
ValueError
(
"The class_to_idx parameter cannot be None."
)
return
make_dataset
(
directory
,
class_to_idx
,
extensions
=
extensions
,
is_valid_file
=
is_valid_file
)
return
make_dataset
(
directory
,
class_to_idx
,
extensions
=
extensions
,
is_valid_file
=
is_valid_file
,
allow_empty
=
allow_empty
)
def
find_classes
(
self
,
directory
:
str
)
->
Tuple
[
List
[
str
],
Dict
[
str
,
int
]]:
def
find_classes
(
self
,
directory
:
str
)
->
Tuple
[
List
[
str
],
Dict
[
str
,
int
]]:
"""Find the class folders in a dataset structured as follows::
"""Find the class folders in a dataset structured as follows::
...
@@ -291,6 +306,8 @@ class ImageFolder(DatasetFolder):
...
@@ -291,6 +306,8 @@ class ImageFolder(DatasetFolder):
loader (callable, optional): A function to load an image given its path.
loader (callable, optional): A function to load an image given its path.
is_valid_file (callable, optional): A function that takes the path of an image file
is_valid_file (callable, optional): A function that takes the path of an image file
and checks if the file is a valid file (used to check for corrupt files).
and checks if the file is a valid file (used to check for corrupt files).
allow_empty (bool, optional): If True, empty folders are considered to be valid classes.
An error is raised on empty folders if False (default).
Attributes:
Attributes:
classes (list): List of the class names sorted alphabetically.
classes (list): List of the class names sorted alphabetically.
...
@@ -305,6 +322,7 @@ class ImageFolder(DatasetFolder):
...
@@ -305,6 +322,7 @@ class ImageFolder(DatasetFolder):
target_transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
loader
:
Callable
[[
str
],
Any
]
=
default_loader
,
loader
:
Callable
[[
str
],
Any
]
=
default_loader
,
is_valid_file
:
Optional
[
Callable
[[
str
],
bool
]]
=
None
,
is_valid_file
:
Optional
[
Callable
[[
str
],
bool
]]
=
None
,
allow_empty
:
bool
=
False
,
):
):
super
().
__init__
(
super
().
__init__
(
root
,
root
,
...
@@ -313,5 +331,6 @@ class ImageFolder(DatasetFolder):
...
@@ -313,5 +331,6 @@ class ImageFolder(DatasetFolder):
transform
=
transform
,
transform
=
transform
,
target_transform
=
target_transform
,
target_transform
=
target_transform
,
is_valid_file
=
is_valid_file
,
is_valid_file
=
is_valid_file
,
allow_empty
=
allow_empty
,
)
)
self
.
imgs
=
self
.
samples
self
.
imgs
=
self
.
samples
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment