OpenDAS / vision · Commit 30397d91 (unverified)
Authored Dec 04, 2023 by Philip Meier; committed via GitHub on Dec 04, 2023

add Imagenette dataset (#8139)

parent 3feb5021
Changes: 5 changed files with 143 additions and 0 deletions (+143, -0)
docs/source/datasets.rst                     +1    -0
test/test_datasets.py                        +35   -0
torchvision/datasets/__init__.py             +2    -0
torchvision/datasets/imagenette.py           +104  -0
torchvision/tv_tensors/_dataset_wrapper.py   +1    -0
docs/source/datasets.rst

@@ -54,6 +54,7 @@ Image classification

    GTSRB
    INaturalist
    ImageNet
    Imagenette
    KMNIST
    LFWPeople
    LSUN
test/test_datasets.py

@@ -3377,6 +3377,41 @@ class Middlebury2014StereoTestCase(datasets_utils.ImageDatasetTestCase):
        pass


class ImagenetteTestCase(datasets_utils.ImageDatasetTestCase):
    DATASET_CLASS = datasets.Imagenette
    ADDITIONAL_CONFIGS = combinations_grid(split=["train", "val"], size=["full", "320px", "160px"])

    _WNIDS = [
        "n01440764",
        "n02102040",
        "n02979186",
        "n03000684",
        "n03028079",
        "n03394916",
        "n03417042",
        "n03425413",
        "n03445777",
        "n03888257",
    ]

    def inject_fake_data(self, tmpdir, config):
        archive_root = "imagenette2"
        if config["size"] != "full":
            archive_root += f"-{config['size'].replace('px', '')}"
        image_root = pathlib.Path(tmpdir) / archive_root / config["split"]

        num_images_per_class = 3
        for wnid in self._WNIDS:
            datasets_utils.create_image_folder(
                root=image_root,
                name=wnid,
                file_name_fn=lambda idx: f"{wnid}_{idx}.JPEG",
                num_examples=num_images_per_class,
            )

        return num_images_per_class * len(self._WNIDS)


class TestDatasetWrapper:
    def test_unknown_type(self):
        unknown_object = object()
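For context, ADDITIONAL_CONFIGS drives the test case over every combination of the keyword arguments given to combinations_grid, a helper that lives in the repository's test utilities and is not shown in this diff. A minimal, purely illustrative sketch of an equivalent helper (this reimplementation is an assumption about its behavior, not the repo's definition):

import itertools

def combinations_grid(**kwargs):
    # Cartesian product of the option lists, returned as a list of config dicts.
    return [dict(zip(kwargs, values)) for values in itertools.product(*kwargs.values())]

# combinations_grid(split=["train", "val"], size=["full", "320px", "160px"])
# yields 6 configs, e.g. {"split": "train", "size": "full"}, {"split": "train", "size": "320px"}, ...
# so ImagenetteTestCase is exercised once per split/size pair.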
torchvision/datasets/__init__.py

@@ -30,6 +30,7 @@ from .food101 import Food101
from .gtsrb import GTSRB
from .hmdb51 import HMDB51
from .imagenet import ImageNet
from .imagenette import Imagenette
from .inaturalist import INaturalist
from .kinetics import Kinetics
from .kitti import Kitti

@@ -128,6 +129,7 @@ __all__ = (
    "InStereo2k",
    "ETH3DStereo",
    "wrap_dataset_for_transforms_v2",
    "Imagenette",
)
...
torchvision/datasets/imagenette.py
new file (0 → 100644)

from pathlib import Path
from typing import Any, Callable, Optional, Tuple

from PIL import Image

from .folder import find_classes, make_dataset
from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset


class Imagenette(VisionDataset):
    """`Imagenette <https://github.com/fastai/imagenette#imagenette-1>`_ image classification dataset.

    Args:
        root (string): Root directory of the Imagenette dataset.
        split (string, optional): The dataset split. Supports ``"train"`` (default), and ``"val"``.
        size (string, optional): The image size. Supports ``"full"`` (default), ``"320px"``, and ``"160px"``.
        download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``. Already
            downloaded archives are not downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version, e.g. ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    Attributes:
        classes (list): List of the class name tuples.
        class_to_idx (dict): Dict with items (class name, class index).
        wnids (list): List of the WordNet IDs.
        wnid_to_idx (dict): Dict with items (WordNet ID, class index).
    """

    _ARCHIVES = {
        "full": ("https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz", "fe2fc210e6bb7c5664d602c3cd71e612"),
        "320px": ("https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz", "3df6f0d01a2c9592104656642f5e78a3"),
        "160px": ("https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz", "e793b78cc4c9e9a4ccc0c1155377a412"),
    }
    _WNID_TO_CLASS = {
        "n01440764": ("tench", "Tinca tinca"),
        "n02102040": ("English springer", "English springer spaniel"),
        "n02979186": ("cassette player",),
        "n03000684": ("chain saw", "chainsaw"),
        "n03028079": ("church", "church building"),
        "n03394916": ("French horn", "horn"),
        "n03417042": ("garbage truck", "dustcart"),
        "n03425413": ("gas pump", "gasoline pump", "petrol pump", "island dispenser"),
        "n03445777": ("golf ball",),
        "n03888257": ("parachute", "chute"),
    }

    def __init__(
        self,
        root: str,
        split: str = "train",
        size: str = "full",
        download=False,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._split = verify_str_arg(split, "split", ["train", "val"])
        self._size = verify_str_arg(size, "size", ["full", "320px", "160px"])

        self._url, self._md5 = self._ARCHIVES[self._size]
        self._size_root = Path(self.root) / Path(self._url).stem
        self._image_root = str(self._size_root / self._split)

        if download:
            self._download()
        elif not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it.")

        self.wnids, self.wnid_to_idx = find_classes(self._image_root)
        self.classes = [self._WNID_TO_CLASS[wnid] for wnid in self.wnids]
        self.class_to_idx = {
            class_name: idx for wnid, idx in self.wnid_to_idx.items() for class_name in self._WNID_TO_CLASS[wnid]
        }
        self._samples = make_dataset(self._image_root, self.wnid_to_idx, extensions=".jpeg")

    def _check_exists(self) -> bool:
        return self._size_root.exists()

    def _download(self):
        if self._check_exists():
            raise RuntimeError(
                f"The directory {self._size_root} already exists. "
                f"If you want to re-download or re-extract the images, delete the directory."
            )

        download_and_extract_archive(self._url, self.root, md5=self._md5)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        path, label = self._samples[idx]
        image = Image.open(path).convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label

    def __len__(self) -> int:
        return len(self._samples)
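Based on the docstring and constructor above, a minimal usage sketch of the new class. The root path is a placeholder, and the example assumes a torchvision build that already contains this commit:

from torchvision.datasets import Imagenette

# Placeholder root; with download=True the selected archive is fetched into it
# and extracted on first use (already-downloaded archives are reused).
dataset = Imagenette(root="data", split="train", size="160px", download=True)

image, label = dataset[0]      # PIL image (RGB) and integer class index
print(len(dataset))            # number of samples in the split
print(dataset.classes[label])  # tuple of class names, e.g. ("tench", "Tinca tinca")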
torchvision/tv_tensors/_dataset_wrapper.py

@@ -284,6 +284,7 @@ for dataset_cls in [
    datasets.GTSRB,
    datasets.DatasetFolder,
    datasets.ImageFolder,
    datasets.Imagenette,
]:
    WRAPPER_FACTORIES.register(dataset_cls)(classification_wrapper_factory)
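Registering Imagenette with classification_wrapper_factory means it can be passed to wrap_dataset_for_transforms_v2 like the other classification datasets in that list. A hedged sketch of what this enables (the root path is a placeholder and the dataset is assumed to be already downloaded):

from torchvision import datasets

plain = datasets.Imagenette(root="data", split="val", size="160px")
wrapped = datasets.wrap_dataset_for_transforms_v2(plain)

# For a plain classification dataset the sample layout stays (image, label);
# the registration above simply tells the transforms-v2 wrapping machinery
# which wrapper factory to use for Imagenette.
image, label = wrapped[0]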