OpenDAS / vision · Commit 48b1edff (Unverified)

Remove prototype area for 0.19 (#8491)

Authored Jun 14, 2024 by Nicolas Hug; committed by GitHub on Jun 14, 2024.
Parent: f44f20cf
Changes: 74 · Showing 20 changed files with 0 additions and 3780 deletions (+0 -3780)
torchvision/prototype/datasets/_builtin/cifar100.categories           +0 -100
torchvision/prototype/datasets/_builtin/clevr.py                      +0 -107
torchvision/prototype/datasets/_builtin/coco.categories               +0 -91
torchvision/prototype/datasets/_builtin/coco.py                       +0 -274
torchvision/prototype/datasets/_builtin/country211.categories         +0 -211
torchvision/prototype/datasets/_builtin/country211.py                 +0 -81
torchvision/prototype/datasets/_builtin/cub200.categories             +0 -200
torchvision/prototype/datasets/_builtin/cub200.py                     +0 -265
torchvision/prototype/datasets/_builtin/dtd.categories                +0 -47
torchvision/prototype/datasets/_builtin/dtd.py                        +0 -139
torchvision/prototype/datasets/_builtin/eurosat.py                    +0 -66
torchvision/prototype/datasets/_builtin/fer2013.py                    +0 -64
torchvision/prototype/datasets/_builtin/food101.categories            +0 -101
torchvision/prototype/datasets/_builtin/food101.py                    +0 -97
torchvision/prototype/datasets/_builtin/gtsrb.py                      +0 -112
torchvision/prototype/datasets/_builtin/imagenet.categories           +0 -1000
torchvision/prototype/datasets/_builtin/imagenet.py                   +0 -223
torchvision/prototype/datasets/_builtin/mnist.py                      +0 -419
torchvision/prototype/datasets/_builtin/oxford-iiit-pet.categories    +0 -37
torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py            +0 -146
torchvision/prototype/datasets/_builtin/cifar100.categories (deleted, 100644 → 0)

apple
aquarium_fish
baby
bear
beaver
bed
bee
beetle
bicycle
bottle
bowl
boy
bridge
bus
butterfly
camel
can
castle
caterpillar
cattle
chair
chimpanzee
clock
cloud
cockroach
couch
crab
crocodile
cup
dinosaur
dolphin
elephant
flatfish
forest
fox
girl
hamster
house
kangaroo
keyboard
lamp
lawn_mower
leopard
lion
lizard
lobster
man
maple_tree
motorcycle
mountain
mouse
mushroom
oak_tree
orange
orchid
otter
palm_tree
pear
pickup_truck
pine_tree
plain
plate
poppy
porcupine
possum
rabbit
raccoon
ray
road
rocket
rose
sea
seal
shark
shrew
skunk
skyscraper
snail
snake
spider
squirrel
streetcar
sunflower
sweet_pepper
table
tank
telephone
television
tiger
tractor
train
trout
tulip
turtle
wardrobe
whale
willow_tree
wolf
woman
worm
torchvision/prototype/datasets/_builtin/clevr.py (deleted, 100644 → 0)

import pathlib
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union

from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, JsonParser, Mapper, UnBatcher
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
    getitem,
    hint_sharding,
    hint_shuffling,
    INFINITE_BUFFER_SIZE,
    path_accessor,
    path_comparator,
)
from torchvision.prototype.tv_tensors import Label

from .._api import register_dataset, register_info

NAME = "clevr"


@register_info(NAME)
def _info() -> Dict[str, Any]:
    return dict()


@register_dataset(NAME)
class CLEVR(Dataset):
    """
    - **homepage**: https://cs.stanford.edu/people/jcjohns/clevr/
    """

    def __init__(
        self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False
    ) -> None:
        self._split = self._verify_str_arg(split, "split", ("train", "val", "test"))
        super().__init__(root, skip_integrity_check=skip_integrity_check)

    def _resources(self) -> List[OnlineResource]:
        archive = HttpResource(
            "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip",
            sha256="5cd61cf1096ed20944df93c9adb31e74d189b8459a94f54ba00090e5c59936d1",
        )
        return [archive]

    def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
        path = pathlib.Path(data[0])
        if path.parents[1].name == "images":
            return 0
        elif path.parent.name == "scenes":
            return 1
        else:
            return None

    def _filter_scene_anns(self, data: Tuple[str, Any]) -> bool:
        key, _ = data
        return key == "scenes"

    def _add_empty_anns(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[str, BinaryIO], None]:
        return data, None

    def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Optional[Dict[str, Any]]]) -> Dict[str, Any]:
        image_data, scenes_data = data
        path, buffer = image_data

        return dict(
            path=path,
            image=EncodedImage.from_file(buffer),
            label=Label(len(scenes_data["objects"])) if scenes_data else None,
        )

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        archive_dp = resource_dps[0]
        images_dp, scenes_dp = Demultiplexer(
            archive_dp,
            2,
            self._classify_archive,
            drop_none=True,
            buffer_size=INFINITE_BUFFER_SIZE,
        )

        images_dp = Filter(images_dp, path_comparator("parent.name", self._split))
        images_dp = hint_shuffling(images_dp)
        images_dp = hint_sharding(images_dp)

        if self._split != "test":
            scenes_dp = Filter(scenes_dp, path_comparator("name", f"CLEVR_{self._split}_scenes.json"))
            scenes_dp = JsonParser(scenes_dp)
            scenes_dp = Mapper(scenes_dp, getitem(1, "scenes"))
            scenes_dp = UnBatcher(scenes_dp)

            dp = IterKeyZipper(
                images_dp,
                scenes_dp,
                key_fn=path_accessor("name"),
                ref_key_fn=getitem("image_filename"),
                buffer_size=INFINITE_BUFFER_SIZE,
            )
        else:
            for _, file in scenes_dp:
                file.close()
            dp = Mapper(images_dp, self._add_empty_anns)

        return Mapper(dp, self._prepare_sample)

    def __len__(self) -> int:
        return 70_000 if self._split == "train" else 15_000

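For context, the datasets removed here were consumed through the prototype registration API rather than instantiated directly. A minimal sketch, assuming a torchvision build from before this commit where `torchvision.prototype.datasets.load` still exists, and a hypothetical `./data` root:

from torchvision.prototype import datasets

# Hypothetical pre-0.19 usage: `load` looks up the dataset registered under
# NAME ("clevr") above and returns an IterDataPipe of sample dicts.
dp = datasets.load("clevr", root="./data", split="val")
for sample in dp:
    print(sample["path"], sample["label"])
    break
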
torchvision/prototype/datasets/_builtin/coco.categories (deleted, 100644 → 0)

__background__,N/A
person,person
bicycle,vehicle
car,vehicle
motorcycle,vehicle
airplane,vehicle
bus,vehicle
train,vehicle
truck,vehicle
boat,vehicle
traffic light,outdoor
fire hydrant,outdoor
N/A,N/A
stop sign,outdoor
parking meter,outdoor
bench,outdoor
bird,animal
cat,animal
dog,animal
horse,animal
sheep,animal
cow,animal
elephant,animal
bear,animal
zebra,animal
giraffe,animal
N/A,N/A
backpack,accessory
umbrella,accessory
N/A,N/A
N/A,N/A
handbag,accessory
tie,accessory
suitcase,accessory
frisbee,sports
skis,sports
snowboard,sports
sports ball,sports
kite,sports
baseball bat,sports
baseball glove,sports
skateboard,sports
surfboard,sports
tennis racket,sports
bottle,kitchen
N/A,N/A
wine glass,kitchen
cup,kitchen
fork,kitchen
knife,kitchen
spoon,kitchen
bowl,kitchen
banana,food
apple,food
sandwich,food
orange,food
broccoli,food
carrot,food
hot dog,food
pizza,food
donut,food
cake,food
chair,furniture
couch,furniture
potted plant,furniture
bed,furniture
N/A,N/A
dining table,furniture
N/A,N/A
N/A,N/A
toilet,furniture
N/A,N/A
tv,electronic
laptop,electronic
mouse,electronic
remote,electronic
keyboard,electronic
cell phone,electronic
microwave,appliance
oven,appliance
toaster,appliance
sink,appliance
refrigerator,appliance
N/A,N/A
book,indoor
clock,indoor
vase,indoor
scissors,indoor
teddy bear,indoor
hair drier,indoor
toothbrush,indoor
torchvision/prototype/datasets/_builtin/coco.py (deleted, 100644 → 0)

import pathlib
import re
from collections import defaultdict, OrderedDict
from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union

import torch
from torchdata.datapipes.iter import (
    Demultiplexer,
    Filter,
    Grouper,
    IterDataPipe,
    IterKeyZipper,
    JsonParser,
    Mapper,
    UnBatcher,
)
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
    getitem,
    hint_sharding,
    hint_shuffling,
    INFINITE_BUFFER_SIZE,
    MappingIterator,
    path_accessor,
    read_categories_file,
)
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import BoundingBoxes, Mask

from .._api import register_dataset, register_info

NAME = "coco"


@register_info(NAME)
def _info() -> Dict[str, Any]:
    categories, super_categories = zip(*read_categories_file(NAME))
    return dict(categories=categories, super_categories=super_categories)


@register_dataset(NAME)
class Coco(Dataset):
    """
    - **homepage**: https://cocodataset.org/
    - **dependencies**:
        - <pycocotools `https://github.com/cocodataset/cocoapi`>_
    """

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        *,
        split: str = "train",
        year: str = "2017",
        annotations: Optional[str] = "instances",
        skip_integrity_check: bool = False,
    ) -> None:
        self._split = self._verify_str_arg(split, "split", {"train", "val"})
        self._year = self._verify_str_arg(year, "year", {"2017", "2014"})
        self._annotations = (
            self._verify_str_arg(annotations, "annotations", self._ANN_DECODERS.keys())
            if annotations is not None
            else None
        )

        info = _info()
        categories, super_categories = info["categories"], info["super_categories"]
        self._categories = categories
        self._category_to_super_category = dict(zip(categories, super_categories))

        super().__init__(root, dependencies=("pycocotools",), skip_integrity_check=skip_integrity_check)

    _IMAGE_URL_BASE = "http://images.cocodataset.org/zips"

    _IMAGES_CHECKSUMS = {
        ("2014", "train"): "ede4087e640bddba550e090eae701092534b554b42b05ac33f0300b984b31775",
        ("2014", "val"): "fe9be816052049c34717e077d9e34aa60814a55679f804cd043e3cbee3b9fde0",
        ("2017", "train"): "69a8bb58ea5f8f99d24875f21416de2e9ded3178e903f1f7603e283b9e06d929",
        ("2017", "val"): "4f7e2ccb2866ec5041993c9cf2a952bbed69647b115d0f74da7ce8f4bef82f05",
    }

    _META_URL_BASE = "http://images.cocodataset.org/annotations"

    _META_CHECKSUMS = {
        "2014": "031296bbc80c45a1d1f76bf9a90ead27e94e99ec629208449507a4917a3bf009",
        "2017": "113a836d90195ee1f884e704da6304dfaaecff1f023f49b6ca93c4aaae470268",
    }

    def _resources(self) -> List[OnlineResource]:
        images = HttpResource(
            f"{self._IMAGE_URL_BASE}/{self._split}{self._year}.zip",
            sha256=self._IMAGES_CHECKSUMS[(self._year, self._split)],
        )
        meta = HttpResource(
            f"{self._META_URL_BASE}/annotations_trainval{self._year}.zip",
            sha256=self._META_CHECKSUMS[self._year],
        )
        return [images, meta]

    def _segmentation_to_mask(
        self, segmentation: Any, *, is_crowd: bool, spatial_size: Tuple[int, int]
    ) -> torch.Tensor:
        from pycocotools import mask

        if is_crowd:
            segmentation = mask.frPyObjects(segmentation, *spatial_size)
        else:
            segmentation = mask.merge(mask.frPyObjects(segmentation, *spatial_size))

        return torch.from_numpy(mask.decode(segmentation)).to(torch.bool)

    def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[str, Any]) -> Dict[str, Any]:
        spatial_size = (image_meta["height"], image_meta["width"])
        labels = [ann["category_id"] for ann in anns]
        return dict(
            segmentations=Mask(
                torch.stack(
                    [
                        self._segmentation_to_mask(
                            ann["segmentation"], is_crowd=ann["iscrowd"], spatial_size=spatial_size
                        )
                        for ann in anns
                    ]
                )
            ),
            areas=torch.as_tensor([ann["area"] for ann in anns]),
            crowds=torch.as_tensor([ann["iscrowd"] for ann in anns], dtype=torch.bool),
            bounding_boxes=BoundingBoxes(
                [ann["bbox"] for ann in anns],
                format="xywh",
                spatial_size=spatial_size,
            ),
            labels=Label(labels, categories=self._categories),
            super_categories=[self._category_to_super_category[self._categories[label]] for label in labels],
            ann_ids=[ann["id"] for ann in anns],
        )

    def _decode_captions_ann(self, anns: List[Dict[str, Any]], image_meta: Dict[str, Any]) -> Dict[str, Any]:
        return dict(
            captions=[ann["caption"] for ann in anns],
            ann_ids=[ann["id"] for ann in anns],
        )

    _ANN_DECODERS = OrderedDict(
        [
            ("instances", _decode_instances_anns),
            ("captions", _decode_captions_ann),
        ]
    )

    _META_FILE_PATTERN = re.compile(
        rf"(?P<annotations>({'|'.join(_ANN_DECODERS.keys())}))_(?P<split>[a-zA-Z]+)(?P<year>\d+)[.]json"
    )

    def _filter_meta_files(self, data: Tuple[str, Any]) -> bool:
        match = self._META_FILE_PATTERN.match(pathlib.Path(data[0]).name)
        return bool(
            match
            and match["split"] == self._split
            and match["year"] == self._year
            and match["annotations"] == self._annotations
        )

    def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]:
        key, _ = data
        if key == "images":
            return 0
        elif key == "annotations":
            return 1
        else:
            return None

    def _prepare_image(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
        path, buffer = data
        return dict(
            path=path,
            image=EncodedImage.from_file(buffer),
        )

    def _prepare_sample(
        self,
        data: Tuple[Tuple[List[Dict[str, Any]], Dict[str, Any]], Tuple[str, BinaryIO]],
    ) -> Dict[str, Any]:
        ann_data, image_data = data
        anns, image_meta = ann_data

        sample = self._prepare_image(image_data)
        # this method is only called if we have annotations
        annotations = cast(str, self._annotations)
        sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta))

        return sample

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        images_dp, meta_dp = resource_dps

        if self._annotations is None:
            dp = hint_shuffling(images_dp)
            dp = hint_sharding(dp)
            dp = hint_shuffling(dp)
            return Mapper(dp, self._prepare_image)

        meta_dp = Filter(meta_dp, self._filter_meta_files)
        meta_dp = JsonParser(meta_dp)
        meta_dp = Mapper(meta_dp, getitem(1))
        meta_dp: IterDataPipe[Dict[str, Dict[str, Any]]] = MappingIterator(meta_dp)
        images_meta_dp, anns_meta_dp = Demultiplexer(
            meta_dp,
            2,
            self._classify_meta,
            drop_none=True,
            buffer_size=INFINITE_BUFFER_SIZE,
        )

        images_meta_dp = Mapper(images_meta_dp, getitem(1))
        images_meta_dp = UnBatcher(images_meta_dp)

        anns_meta_dp = Mapper(anns_meta_dp, getitem(1))
        anns_meta_dp = UnBatcher(anns_meta_dp)
        anns_meta_dp = Grouper(anns_meta_dp, group_key_fn=getitem("image_id"), buffer_size=INFINITE_BUFFER_SIZE)
        anns_meta_dp = hint_shuffling(anns_meta_dp)
        anns_meta_dp = hint_sharding(anns_meta_dp)

        anns_dp = IterKeyZipper(
            anns_meta_dp,
            images_meta_dp,
            key_fn=getitem(0, "image_id"),
            ref_key_fn=getitem("id"),
            buffer_size=INFINITE_BUFFER_SIZE,
        )
        dp = IterKeyZipper(
            anns_dp,
            images_dp,
            key_fn=getitem(1, "file_name"),
            ref_key_fn=path_accessor("name"),
            buffer_size=INFINITE_BUFFER_SIZE,
        )
        return Mapper(dp, self._prepare_sample)

    def __len__(self) -> int:
        return {
            ("train", "2017"): defaultdict(lambda: 118_287, instances=117_266),
            ("train", "2014"): defaultdict(lambda: 82_783, instances=82_081),
            ("val", "2017"): defaultdict(lambda: 5_000, instances=4_952),
            ("val", "2014"): defaultdict(lambda: 40_504, instances=40_137),
        }[(self._split, self._year)][
            self._annotations  # type: ignore[index]
        ]

    def _generate_categories(self) -> Tuple[Tuple[str, str]]:
        self._annotations = "instances"
        resources = self._resources()

        dp = resources[1].load(self._root)
        dp = Filter(dp, self._filter_meta_files)
        dp = JsonParser(dp)

        _, meta = next(iter(dp))
        # List[Tuple[super_category, id, category]]
        label_data = [cast(Tuple[str, int, str], tuple(info.values())) for info in meta["categories"]]

        # COCO actually defines 91 categories, but only 80 of them have instances. Still, the category_id refers to the
        # full set. To keep the labels dense, we fill the gaps with N/A. Note that there are only 10 gaps, so the total
        # number of categories is 90 rather than 91.
        _, ids, _ = zip(*label_data)
        missing_ids = set(range(1, max(ids) + 1)) - set(ids)
        label_data.extend([("N/A", id, "N/A") for id in missing_ids])

        # We also add a background category to be used during segmentation.
        label_data.append(("N/A", 0, "__background__"))

        super_categories, _, categories = zip(*sorted(label_data, key=lambda info: info[1]))

        return cast(Tuple[Tuple[str, str]], tuple(zip(categories, super_categories)))

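The polygon handling in `Coco._segmentation_to_mask` can be exercised standalone. A small sketch, assuming pycocotools is installed and using a toy square polygon in place of a real COCO annotation:

import torch
from pycocotools import mask as coco_mask

height, width = 4, 4
segmentation = [[0.0, 0.0, 3.0, 0.0, 3.0, 3.0, 0.0, 3.0]]  # toy polygon, [x0, y0, x1, y1, ...]

rles = coco_mask.frPyObjects(segmentation, height, width)   # polygons -> run-length encodings
rle = coco_mask.merge(rles)                                 # union of all polygon parts
m = torch.from_numpy(coco_mask.decode(rle)).to(torch.bool)  # HxW boolean mask, as in the non-crowd branch
print(m.shape)  # torch.Size([4, 4])
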
torchvision/prototype/datasets/_builtin/country211.categories (deleted, 100644 → 0)

AD
AE
AF
AG
AI
AL
AM
AO
AQ
AR
AT
AU
AW
AX
AZ
BA
BB
BD
BE
BF
BG
BH
BJ
BM
BN
BO
BQ
BR
BS
BT
BW
BY
BZ
CA
CD
CF
CH
CI
CK
CL
CM
CN
CO
CR
CU
CV
CW
CY
CZ
DE
DK
DM
DO
DZ
EC
EE
EG
ES
ET
FI
FJ
FK
FO
FR
GA
GB
GD
GE
GF
GG
GH
GI
GL
GM
GP
GR
GS
GT
GU
GY
HK
HN
HR
HT
HU
ID
IE
IL
IM
IN
IQ
IR
IS
IT
JE
JM
JO
JP
KE
KG
KH
KN
KP
KR
KW
KY
KZ
LA
LB
LC
LI
LK
LR
LT
LU
LV
LY
MA
MC
MD
ME
MF
MG
MK
ML
MM
MN
MO
MQ
MR
MT
MU
MV
MW
MX
MY
MZ
NA
NC
NG
NI
NL
NO
NP
NZ
OM
PA
PE
PF
PG
PH
PK
PL
PR
PS
PT
PW
PY
QA
RE
RO
RS
RU
RW
SA
SB
SC
SD
SE
SG
SH
SI
SJ
SK
SL
SM
SN
SO
SS
SV
SX
SY
SZ
TG
TH
TJ
TL
TM
TN
TO
TR
TT
TW
TZ
UA
UG
US
UY
UZ
VA
VE
VG
VI
VN
VU
WS
XK
YE
ZA
ZM
ZW
torchvision/prototype/datasets/_builtin/country211.py (deleted, 100644 → 0)

import pathlib
from typing import Any, Dict, List, Tuple, Union

from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
    hint_sharding,
    hint_shuffling,
    path_comparator,
    read_categories_file,
)
from torchvision.prototype.tv_tensors import Label

from .._api import register_dataset, register_info

NAME = "country211"


@register_info(NAME)
def _info() -> Dict[str, Any]:
    return dict(categories=read_categories_file(NAME))


@register_dataset(NAME)
class Country211(Dataset):
    """
    - **homepage**: https://github.com/openai/CLIP/blob/main/data/country211.md
    """

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        *,
        split: str = "train",
        skip_integrity_check: bool = False,
    ) -> None:
        self._split = self._verify_str_arg(split, "split", ("train", "val", "test"))
        self._split_folder_name = "valid" if split == "val" else split

        self._categories = _info()["categories"]

        super().__init__(root, skip_integrity_check=skip_integrity_check)

    def _resources(self) -> List[OnlineResource]:
        return [
            HttpResource(
                "https://openaipublic.azureedge.net/clip/data/country211.tgz",
                sha256="c011343cdc1296a8c31ff1d7129cf0b5e5b8605462cffd24f89266d6e6f4da3c",
            )
        ]

    def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]:
        path, buffer = data
        category = pathlib.Path(path).parent.name
        return dict(
            label=Label.from_category(category, categories=self._categories),
            path=path,
            image=EncodedImage.from_file(buffer),
        )

    def _filter_split(self, data: Tuple[str, Any], *, split: str) -> bool:
        return pathlib.Path(data[0]).parent.parent.name == split

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        dp = resource_dps[0]
        dp = Filter(dp, path_comparator("parent.parent.name", self._split_folder_name))
        dp = hint_shuffling(dp)
        dp = hint_sharding(dp)
        return Mapper(dp, self._prepare_sample)

    def __len__(self) -> int:
        return {
            "train": 31_650,
            "val": 10_550,
            "test": 21_100,
        }[self._split]

    def _generate_categories(self) -> List[str]:
        resources = self._resources()
        dp = resources[0].load(self._root)
        return sorted({pathlib.Path(path).parent.name for path, _ in dp})

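The datapipe above derives both the split and the label purely from the directory layout of the extracted archive. A sketch with a hypothetical path from country211.tgz:

import pathlib

path = pathlib.Path("country211/valid/AD/0001.jpg")  # hypothetical sample path
print(path.parent.name)         # "AD": the category used for the Label
print(path.parent.parent.name)  # "valid": matched against self._split_folder_name
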
torchvision/prototype/datasets/_builtin/cub200.categories (deleted, 100644 → 0)

Black_footed_Albatross
Laysan_Albatross
Sooty_Albatross
Groove_billed_Ani
Crested_Auklet
Least_Auklet
Parakeet_Auklet
Rhinoceros_Auklet
Brewer_Blackbird
Red_winged_Blackbird
Rusty_Blackbird
Yellow_headed_Blackbird
Bobolink
Indigo_Bunting
Lazuli_Bunting
Painted_Bunting
Cardinal
Spotted_Catbird
Gray_Catbird
Yellow_breasted_Chat
Eastern_Towhee
Chuck_will_Widow
Brandt_Cormorant
Red_faced_Cormorant
Pelagic_Cormorant
Bronzed_Cowbird
Shiny_Cowbird
Brown_Creeper
American_Crow
Fish_Crow
Black_billed_Cuckoo
Mangrove_Cuckoo
Yellow_billed_Cuckoo
Gray_crowned_Rosy_Finch
Purple_Finch
Northern_Flicker
Acadian_Flycatcher
Great_Crested_Flycatcher
Least_Flycatcher
Olive_sided_Flycatcher
Scissor_tailed_Flycatcher
Vermilion_Flycatcher
Yellow_bellied_Flycatcher
Frigatebird
Northern_Fulmar
Gadwall
American_Goldfinch
European_Goldfinch
Boat_tailed_Grackle
Eared_Grebe
Horned_Grebe
Pied_billed_Grebe
Western_Grebe
Blue_Grosbeak
Evening_Grosbeak
Pine_Grosbeak
Rose_breasted_Grosbeak
Pigeon_Guillemot
California_Gull
Glaucous_winged_Gull
Heermann_Gull
Herring_Gull
Ivory_Gull
Ring_billed_Gull
Slaty_backed_Gull
Western_Gull
Anna_Hummingbird
Ruby_throated_Hummingbird
Rufous_Hummingbird
Green_Violetear
Long_tailed_Jaeger
Pomarine_Jaeger
Blue_Jay
Florida_Jay
Green_Jay
Dark_eyed_Junco
Tropical_Kingbird
Gray_Kingbird
Belted_Kingfisher
Green_Kingfisher
Pied_Kingfisher
Ringed_Kingfisher
White_breasted_Kingfisher
Red_legged_Kittiwake
Horned_Lark
Pacific_Loon
Mallard
Western_Meadowlark
Hooded_Merganser
Red_breasted_Merganser
Mockingbird
Nighthawk
Clark_Nutcracker
White_breasted_Nuthatch
Baltimore_Oriole
Hooded_Oriole
Orchard_Oriole
Scott_Oriole
Ovenbird
Brown_Pelican
White_Pelican
Western_Wood_Pewee
Sayornis
American_Pipit
Whip_poor_Will
Horned_Puffin
Common_Raven
White_necked_Raven
American_Redstart
Geococcyx
Loggerhead_Shrike
Great_Grey_Shrike
Baird_Sparrow
Black_throated_Sparrow
Brewer_Sparrow
Chipping_Sparrow
Clay_colored_Sparrow
House_Sparrow
Field_Sparrow
Fox_Sparrow
Grasshopper_Sparrow
Harris_Sparrow
Henslow_Sparrow
Le_Conte_Sparrow
Lincoln_Sparrow
Nelson_Sharp_tailed_Sparrow
Savannah_Sparrow
Seaside_Sparrow
Song_Sparrow
Tree_Sparrow
Vesper_Sparrow
White_crowned_Sparrow
White_throated_Sparrow
Cape_Glossy_Starling
Bank_Swallow
Barn_Swallow
Cliff_Swallow
Tree_Swallow
Scarlet_Tanager
Summer_Tanager
Artic_Tern
Black_Tern
Caspian_Tern
Common_Tern
Elegant_Tern
Forsters_Tern
Least_Tern
Green_tailed_Towhee
Brown_Thrasher
Sage_Thrasher
Black_capped_Vireo
Blue_headed_Vireo
Philadelphia_Vireo
Red_eyed_Vireo
Warbling_Vireo
White_eyed_Vireo
Yellow_throated_Vireo
Bay_breasted_Warbler
Black_and_white_Warbler
Black_throated_Blue_Warbler
Blue_winged_Warbler
Canada_Warbler
Cape_May_Warbler
Cerulean_Warbler
Chestnut_sided_Warbler
Golden_winged_Warbler
Hooded_Warbler
Kentucky_Warbler
Magnolia_Warbler
Mourning_Warbler
Myrtle_Warbler
Nashville_Warbler
Orange_crowned_Warbler
Palm_Warbler
Pine_Warbler
Prairie_Warbler
Prothonotary_Warbler
Swainson_Warbler
Tennessee_Warbler
Wilson_Warbler
Worm_eating_Warbler
Yellow_Warbler
Northern_Waterthrush
Louisiana_Waterthrush
Bohemian_Waxwing
Cedar_Waxwing
American_Three_toed_Woodpecker
Pileated_Woodpecker
Red_bellied_Woodpecker
Red_cockaded_Woodpecker
Red_headed_Woodpecker
Downy_Woodpecker
Bewick_Wren
Cactus_Wren
Carolina_Wren
House_Wren
Marsh_Wren
Rock_Wren
Winter_Wren
Common_Yellowthroat
torchvision/prototype/datasets/_builtin/cub200.py (deleted, 100644 → 0)

import csv
import functools
import pathlib
from typing import Any, BinaryIO, Callable, Dict, List, Optional, Tuple, Union

import torch
from torchdata.datapipes.iter import (
    CSVDictParser,
    CSVParser,
    Demultiplexer,
    Filter,
    IterDataPipe,
    IterKeyZipper,
    LineReader,
    Mapper,
)
from torchdata.datapipes.map import IterToMapConverter
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
    getitem,
    hint_sharding,
    hint_shuffling,
    INFINITE_BUFFER_SIZE,
    path_accessor,
    path_comparator,
    read_categories_file,
    read_mat,
)
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import BoundingBoxes

from .._api import register_dataset, register_info

csv.register_dialect("cub200", delimiter=" ")

NAME = "cub200"


@register_info(NAME)
def _info() -> Dict[str, Any]:
    return dict(categories=read_categories_file(NAME))


@register_dataset(NAME)
class CUB200(Dataset):
    """
    - **homepage**: http://www.vision.caltech.edu/visipedia/CUB-200.html
    """

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        *,
        split: str = "train",
        year: str = "2011",
        skip_integrity_check: bool = False,
    ) -> None:
        self._split = self._verify_str_arg(split, "split", ("train", "test"))
        self._year = self._verify_str_arg(year, "year", ("2010", "2011"))

        self._categories = _info()["categories"]

        super().__init__(
            root,
            # TODO: this will only be available after https://github.com/pytorch/vision/pull/5473
            # dependencies=("scipy",),
            skip_integrity_check=skip_integrity_check,
        )

    def _resources(self) -> List[OnlineResource]:
        if self._year == "2011":
            archive = GDriveResource(
                "1hbzc_P1FuxMkcabkgn9ZKinBwW683j45",
                file_name="CUB_200_2011.tgz",
                sha256="0c685df5597a8b24909f6a7c9db6d11e008733779a671760afef78feb49bf081",
                preprocess="decompress",
            )
            segmentations = GDriveResource(
                "1EamOKGLoTuZdtcVYbHMWNpkn3iAVj8TP",
                file_name="segmentations.tgz",
                sha256="dc77f6cffea0cbe2e41d4201115c8f29a6320ecb04fffd2444f51b8066e4b84f",
                preprocess="decompress",
            )
            return [archive, segmentations]
        else:  # self._year == "2010"
            split = GDriveResource(
                "1vZuZPqha0JjmwkdaS_XtYryE3Jf5Q1AC",
                file_name="lists.tgz",
                sha256="aeacbd5e3539ae84ea726e8a266a9a119c18f055cd80f3836d5eb4500b005428",
                preprocess="decompress",
            )
            images = GDriveResource(
                "1GDr1OkoXdhaXWGA8S3MAq3a522Tak-nx",
                file_name="images.tgz",
                sha256="2a6d2246bbb9778ca03aa94e2e683ccb4f8821a36b7f235c0822e659d60a803e",
                preprocess="decompress",
            )
            anns = GDriveResource(
                "16NsbTpMs5L6hT4hUJAmpW2u7wH326WTR",
                file_name="annotations.tgz",
                sha256="c17b7841c21a66aa44ba8fe92369cc95dfc998946081828b1d7b8a4b716805c1",
                preprocess="decompress",
            )
            return [split, images, anns]

    def _2011_classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
        path = pathlib.Path(data[0])
        if path.parents[1].name == "images":
            return 0
        elif path.name == "train_test_split.txt":
            return 1
        elif path.name == "images.txt":
            return 2
        elif path.name == "bounding_boxes.txt":
            return 3
        else:
            return None

    def _2011_extract_file_name(self, rel_posix_path: str) -> str:
        return rel_posix_path.rsplit("/", maxsplit=1)[1]

    def _2011_filter_split(self, row: List[str]) -> bool:
        _, split_id = row
        return {
            "0": "test",
            "1": "train",
        }[split_id] == self._split

    def _2011_segmentation_key(self, data: Tuple[str, Any]) -> str:
        path = pathlib.Path(data[0])
        return path.with_suffix(".jpg").name

    def _2011_prepare_ann(
        self, data: Tuple[str, Tuple[List[str], Tuple[str, BinaryIO]]], spatial_size: Tuple[int, int]
    ) -> Dict[str, Any]:
        _, (bounding_boxes_data, segmentation_data) = data
        segmentation_path, segmentation_buffer = segmentation_data
        return dict(
            bounding_boxes=BoundingBoxes(
                [float(part) for part in bounding_boxes_data[1:]], format="xywh", spatial_size=spatial_size
            ),
            segmentation_path=segmentation_path,
            segmentation=EncodedImage.from_file(segmentation_buffer),
        )

    def _2010_split_key(self, data: str) -> str:
        return data.rsplit("/", maxsplit=1)[1]

    def _2010_anns_key(self, data: Tuple[str, BinaryIO]) -> Tuple[str, Tuple[str, BinaryIO]]:
        path = pathlib.Path(data[0])
        return path.with_suffix(".jpg").name, data

    def _2010_prepare_ann(
        self, data: Tuple[str, Tuple[str, BinaryIO]], spatial_size: Tuple[int, int]
    ) -> Dict[str, Any]:
        _, (path, buffer) = data
        content = read_mat(buffer)
        return dict(
            ann_path=path,
            bounding_boxes=BoundingBoxes(
                [int(content["bbox"][coord]) for coord in ("left", "bottom", "right", "top")],
                format="xyxy",
                spatial_size=spatial_size,
            ),
            segmentation=torch.as_tensor(content["seg"]),
        )

    def _prepare_sample(
        self,
        data: Tuple[Tuple[str, Tuple[str, BinaryIO]], Any],
        *,
        prepare_ann_fn: Callable[[Any, Tuple[int, int]], Dict[str, Any]],
    ) -> Dict[str, Any]:
        data, anns_data = data
        _, image_data = data
        path, buffer = image_data

        image = EncodedImage.from_file(buffer)

        return dict(
            prepare_ann_fn(anns_data, image.spatial_size),
            image=image,
            label=Label(
                int(pathlib.Path(path).parent.name.rsplit(".", 1)[0]) - 1,
                categories=self._categories,
            ),
        )

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        prepare_ann_fn: Callable
        if self._year == "2011":
            archive_dp, segmentations_dp = resource_dps
            images_dp, split_dp, image_files_dp, bounding_boxes_dp = Demultiplexer(
                archive_dp, 4, self._2011_classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
            )

            image_files_dp = CSVParser(image_files_dp, dialect="cub200")
            image_files_dp = Mapper(image_files_dp, self._2011_extract_file_name, input_col=1)
            image_files_map = IterToMapConverter(image_files_dp)

            split_dp = CSVParser(split_dp, dialect="cub200")
            split_dp = Filter(split_dp, self._2011_filter_split)
            split_dp = Mapper(split_dp, getitem(0))
            split_dp = Mapper(split_dp, image_files_map.__getitem__)

            bounding_boxes_dp = CSVParser(bounding_boxes_dp, dialect="cub200")
            bounding_boxes_dp = Mapper(bounding_boxes_dp, image_files_map.__getitem__, input_col=0)

            anns_dp = IterKeyZipper(
                bounding_boxes_dp,
                segmentations_dp,
                key_fn=getitem(0),
                ref_key_fn=self._2011_segmentation_key,
                keep_key=True,
                buffer_size=INFINITE_BUFFER_SIZE,
            )

            prepare_ann_fn = self._2011_prepare_ann
        else:  # self._year == "2010"
            split_dp, images_dp, anns_dp = resource_dps

            split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
            split_dp = LineReader(split_dp, decode=True, return_path=False)
            split_dp = Mapper(split_dp, self._2010_split_key)

            anns_dp = Mapper(anns_dp, self._2010_anns_key)

            prepare_ann_fn = self._2010_prepare_ann

        split_dp = hint_shuffling(split_dp)
        split_dp = hint_sharding(split_dp)

        dp = IterKeyZipper(
            split_dp,
            images_dp,
            getitem(),
            path_accessor("name"),
            buffer_size=INFINITE_BUFFER_SIZE,
        )
        dp = IterKeyZipper(
            dp,
            anns_dp,
            getitem(0),
            buffer_size=INFINITE_BUFFER_SIZE,
        )
        return Mapper(dp, functools.partial(self._prepare_sample, prepare_ann_fn=prepare_ann_fn))

    def __len__(self) -> int:
        return {
            ("train", "2010"): 3_000,
            ("test", "2010"): 3_033,
            ("train", "2011"): 5_994,
            ("test", "2011"): 5_794,
        }[(self._split, self._year)]

    def _generate_categories(self) -> List[str]:
        self._year = "2011"
        resources = self._resources()

        dp = resources[0].load(self._root)
        dp = Filter(dp, path_comparator("name", "classes.txt"))
        dp = CSVDictParser(dp, fieldnames=("label", "category"), dialect="cub200")

        return [row["category"].split(".")[1] for row in dp]

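The 2011 branch above builds an in-memory lookup from image id to file name with IterToMapConverter, then resolves split rows and bounding-box rows through it. A sketch fed by two hypothetical rows from images.txt:

from torchdata.datapipes.iter import IterableWrapper
from torchdata.datapipes.map import IterToMapConverter

# Each item must be a (key, value) pair; the first element becomes the key.
rows = IterableWrapper([("1", "Black_Footed_Albatross_0001.jpg"), ("2", "Laysan_Albatross_0001.jpg")])
lookup = IterToMapConverter(rows)
print(lookup["2"])  # "Laysan_Albatross_0001.jpg"
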
torchvision/prototype/datasets/_builtin/dtd.categories (deleted, 100644 → 0)

banded
blotchy
braided
bubbly
bumpy
chequered
cobwebbed
cracked
crosshatched
crystalline
dotted
fibrous
flecked
freckled
frilly
gauzy
grid
grooved
honeycombed
interlaced
knitted
lacelike
lined
marbled
matted
meshed
paisley
perforated
pitted
pleated
polka-dotted
porous
potholed
scaly
smeared
spiralled
sprinkled
stained
stratified
striped
studded
swirly
veined
waffled
woven
wrinkled
zigzagged
torchvision/prototype/datasets/_builtin/dtd.py (deleted, 100644 → 0)

import enum
import pathlib
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union

from torchdata.datapipes.iter import CSVParser, Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
    getitem,
    hint_sharding,
    hint_shuffling,
    INFINITE_BUFFER_SIZE,
    path_comparator,
    read_categories_file,
)
from torchvision.prototype.tv_tensors import Label

from .._api import register_dataset, register_info

NAME = "dtd"


class DTDDemux(enum.IntEnum):
    SPLIT = 0
    JOINT_CATEGORIES = 1
    IMAGES = 2


@register_info(NAME)
def _info() -> Dict[str, Any]:
    return dict(categories=read_categories_file(NAME))


@register_dataset(NAME)
class DTD(Dataset):
    """DTD Dataset.
    homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/",
    """

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        *,
        split: str = "train",
        fold: int = 1,
        skip_validation_check: bool = False,
    ) -> None:
        self._split = self._verify_str_arg(split, "split", {"train", "val", "test"})

        if not (1 <= fold <= 10):
            raise ValueError(f"The fold parameter should be an integer in [1, 10]. Got {fold}")
        self._fold = fold

        self._categories = _info()["categories"]

        super().__init__(root, skip_integrity_check=skip_validation_check)

    def _resources(self) -> List[OnlineResource]:
        archive = HttpResource(
            "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz",
            sha256="e42855a52a4950a3b59612834602aa253914755c95b0cff9ead6d07395f8e205",
            preprocess="decompress",
        )
        return [archive]

    def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
        path = pathlib.Path(data[0])
        if path.parent.name == "labels":
            if path.name == "labels_joint_anno.txt":
                return DTDDemux.JOINT_CATEGORIES
            return DTDDemux.SPLIT
        elif path.parents[1].name == "images":
            return DTDDemux.IMAGES
        else:
            return None

    def _image_key_fn(self, data: Tuple[str, Any]) -> str:
        path = pathlib.Path(data[0])
        # The split files contain hardcoded posix paths for the images, e.g. banded/banded_0001.jpg
        return str(path.relative_to(path.parents[1]).as_posix())

    def _prepare_sample(self, data: Tuple[Tuple[str, List[str]], Tuple[str, BinaryIO]]) -> Dict[str, Any]:
        (_, joint_categories_data), image_data = data
        _, *joint_categories = joint_categories_data
        path, buffer = image_data

        category = pathlib.Path(path).parent.name

        return dict(
            joint_categories={category for category in joint_categories if category},
            label=Label.from_category(category, categories=self._categories),
            path=path,
            image=EncodedImage.from_file(buffer),
        )

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        archive_dp = resource_dps[0]

        splits_dp, joint_categories_dp, images_dp = Demultiplexer(
            archive_dp, 3, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
        )

        splits_dp = Filter(splits_dp, path_comparator("name", f"{self._split}{self._fold}.txt"))
        splits_dp = LineReader(splits_dp, decode=True, return_path=False)
        splits_dp = hint_shuffling(splits_dp)
        splits_dp = hint_sharding(splits_dp)

        joint_categories_dp = CSVParser(joint_categories_dp, delimiter=" ")

        dp = IterKeyZipper(
            splits_dp,
            joint_categories_dp,
            key_fn=getitem(),
            ref_key_fn=getitem(0),
            buffer_size=INFINITE_BUFFER_SIZE,
        )
        dp = IterKeyZipper(
            dp,
            images_dp,
            key_fn=getitem(0),
            ref_key_fn=self._image_key_fn,
            buffer_size=INFINITE_BUFFER_SIZE,
        )
        return Mapper(dp, self._prepare_sample)

    def _filter_images(self, data: Tuple[str, Any]) -> bool:
        return self._classify_archive(data) == DTDDemux.IMAGES

    def _generate_categories(self) -> List[str]:
        resources = self._resources()

        dp = resources[0].load(self._root)
        dp = Filter(dp, self._filter_images)

        return sorted({pathlib.Path(path).parent.name for path, _ in dp})

    def __len__(self) -> int:
        return 1_880  # All splits have the same length

torchvision/prototype/datasets/_builtin/eurosat.py (deleted, 100644 → 0)

import pathlib
from typing import Any, Dict, List, Tuple, Union

from torchdata.datapipes.iter import IterDataPipe, Mapper
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.tv_tensors import Label

from .._api import register_dataset, register_info

NAME = "eurosat"


@register_info(NAME)
def _info() -> Dict[str, Any]:
    return dict(
        categories=(
            "AnnualCrop",
            "Forest",
            "HerbaceousVegetation",
            "Highway",
            "Industrial",
            "Pasture",
            "PermanentCrop",
            "Residential",
            "River",
            "SeaLake",
        )
    )


@register_dataset(NAME)
class EuroSAT(Dataset):
    """EuroSAT Dataset.
    homepage="https://github.com/phelber/eurosat",
    """

    def __init__(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> None:
        self._categories = _info()["categories"]
        super().__init__(root, skip_integrity_check=skip_integrity_check)

    def _resources(self) -> List[OnlineResource]:
        return [
            HttpResource(
                "https://madm.dfki.de/files/sentinel/EuroSAT.zip",
                sha256="8ebea626349354c5328b142b96d0430e647051f26efc2dc974c843f25ecf70bd",
            )
        ]

    def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]:
        path, buffer = data
        category = pathlib.Path(path).parent.name
        return dict(
            label=Label.from_category(category, categories=self._categories),
            path=path,
            image=EncodedImage.from_file(buffer),
        )

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        dp = resource_dps[0]
        dp = hint_shuffling(dp)
        dp = hint_sharding(dp)
        return Mapper(dp, self._prepare_sample)

    def __len__(self) -> int:
        return 27_000

torchvision/prototype/datasets/_builtin/fer2013.py (deleted, 100644 → 0)

import pathlib
from typing import Any, Dict, List, Union

import torch
from torchdata.datapipes.iter import CSVDictParser, IterDataPipe, Mapper
from torchvision.prototype.datasets.utils import Dataset, KaggleDownloadResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import Image

from .._api import register_dataset, register_info

NAME = "fer2013"


@register_info(NAME)
def _info() -> Dict[str, Any]:
    return dict(categories=("angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"))


@register_dataset(NAME)
class FER2013(Dataset):
    """FER 2013 Dataset
    homepage="https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
    """

    def __init__(
        self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False
    ) -> None:
        self._split = self._verify_str_arg(split, "split", {"train", "test"})
        self._categories = _info()["categories"]

        super().__init__(root, skip_integrity_check=skip_integrity_check)

    _CHECKSUMS = {
        "train": "a2b7c9360cc0b38d21187e5eece01c2799fce5426cdeecf746889cc96cda2d10",
        "test": "dec8dfe8021e30cd6704b85ec813042b4a5d99d81cb55e023291a94104f575c3",
    }

    def _resources(self) -> List[OnlineResource]:
        archive = KaggleDownloadResource(
            "https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge",
            file_name=f"{self._split}.csv.zip",
            sha256=self._CHECKSUMS[self._split],
        )
        return [archive]

    def _prepare_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
        label_id = data.get("emotion")

        return dict(
            image=Image(torch.tensor([int(idx) for idx in data["pixels"].split()], dtype=torch.uint8).reshape(48, 48)),
            label=Label(int(label_id), categories=self._categories) if label_id is not None else None,
        )

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        dp = resource_dps[0]
        dp = CSVDictParser(dp)
        dp = hint_shuffling(dp)
        dp = hint_sharding(dp)
        return Mapper(dp, self._prepare_sample)

    def __len__(self) -> int:
        return 28_709 if self._split == "train" else 3_589

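`_prepare_sample` above decodes the FER2013 "pixels" CSV column, a whitespace-separated string of 2304 grayscale values, into a 48x48 uint8 image. A standalone sketch with a synthetic cell in place of a real row:

import torch

pixels = " ".join(["128"] * 48 * 48)  # synthetic stand-in for a real "pixels" cell
img = torch.tensor([int(v) for v in pixels.split()], dtype=torch.uint8).reshape(48, 48)
print(img.shape, img.dtype)  # torch.Size([48, 48]) torch.uint8
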
torchvision/prototype/datasets/_builtin/food101.categories (deleted, 100644 → 0)

apple_pie
baby_back_ribs
baklava
beef_carpaccio
beef_tartare
beet_salad
beignets
bibimbap
bread_pudding
breakfast_burrito
bruschetta
caesar_salad
cannoli
caprese_salad
carrot_cake
ceviche
cheesecake
cheese_plate
chicken_curry
chicken_quesadilla
chicken_wings
chocolate_cake
chocolate_mousse
churros
clam_chowder
club_sandwich
crab_cakes
creme_brulee
croque_madame
cup_cakes
deviled_eggs
donuts
dumplings
edamame
eggs_benedict
escargots
falafel
filet_mignon
fish_and_chips
foie_gras
french_fries
french_onion_soup
french_toast
fried_calamari
fried_rice
frozen_yogurt
garlic_bread
gnocchi
greek_salad
grilled_cheese_sandwich
grilled_salmon
guacamole
gyoza
hamburger
hot_and_sour_soup
hot_dog
huevos_rancheros
hummus
ice_cream
lasagna
lobster_bisque
lobster_roll_sandwich
macaroni_and_cheese
macarons
miso_soup
mussels
nachos
omelette
onion_rings
oysters
pad_thai
paella
pancakes
panna_cotta
peking_duck
pho
pizza
pork_chop
poutine
prime_rib
pulled_pork_sandwich
ramen
ravioli
red_velvet_cake
risotto
samosa
sashimi
scallops
seaweed_salad
shrimp_and_grits
spaghetti_bolognese
spaghetti_carbonara
spring_rolls
steak
strawberry_shortcake
sushi
tacos
takoyaki
tiramisu
tuna_tartare
waffles
torchvision/prototype/datasets/_builtin/food101.py (deleted, 100644 → 0)

from pathlib import Path
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union

from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
    getitem,
    hint_sharding,
    hint_shuffling,
    INFINITE_BUFFER_SIZE,
    path_comparator,
    read_categories_file,
)
from torchvision.prototype.tv_tensors import Label

from .._api import register_dataset, register_info

NAME = "food101"


@register_info(NAME)
def _info() -> Dict[str, Any]:
    return dict(categories=read_categories_file(NAME))


@register_dataset(NAME)
class Food101(Dataset):
    """Food 101 dataset
    homepage="https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101",
    """

    def __init__(self, root: Union[str, Path], *, split: str = "train", skip_integrity_check: bool = False) -> None:
        self._split = self._verify_str_arg(split, "split", {"train", "test"})
        self._categories = _info()["categories"]

        super().__init__(root, skip_integrity_check=skip_integrity_check)

    def _resources(self) -> List[OnlineResource]:
        return [
            HttpResource(
                url="http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz",
                sha256="d97d15e438b7f4498f96086a4f7e2fa42a32f2712e87d3295441b2b6314053a4",
                preprocess="decompress",
            )
        ]

    def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
        path = Path(data[0])
        if path.parents[1].name == "images":
            return 0
        elif path.parents[0].name == "meta":
            return 1
        else:
            return None

    def _prepare_sample(self, data: Tuple[str, Tuple[str, BinaryIO]]) -> Dict[str, Any]:
        id, (path, buffer) = data
        return dict(
            label=Label.from_category(id.split("/", 1)[0], categories=self._categories),
            path=path,
            image=EncodedImage.from_file(buffer),
        )

    def _image_key(self, data: Tuple[str, Any]) -> str:
        path = Path(data[0])
        return path.relative_to(path.parents[1]).with_suffix("").as_posix()

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        archive_dp = resource_dps[0]
        images_dp, split_dp = Demultiplexer(
            archive_dp, 2, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
        )
        split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
        split_dp = LineReader(split_dp, decode=True, return_path=False)
        split_dp = hint_sharding(split_dp)
        split_dp = hint_shuffling(split_dp)

        dp = IterKeyZipper(
            split_dp,
            images_dp,
            key_fn=getitem(),
            ref_key_fn=self._image_key,
            buffer_size=INFINITE_BUFFER_SIZE,
        )

        return Mapper(dp, self._prepare_sample)

    def _generate_categories(self) -> List[str]:
        resources = self._resources()
        dp = resources[0].load(self._root)
        dp = Filter(dp, path_comparator("name", "classes.txt"))
        dp = LineReader(dp, decode=True, return_path=False)
        return list(dp)

    def __len__(self) -> int:
        return 75_750 if self._split == "train" else 25_250

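`_image_key` above normalizes archive paths to the "<category>/<id>" keys that the split files use, which is what lets IterKeyZipper join the two streams. A sketch with a hypothetical archive path:

from pathlib import Path

path = Path("food-101/images/apple_pie/1005649.jpg")  # hypothetical archive member
print(path.relative_to(path.parents[1]).with_suffix("").as_posix())  # "apple_pie/1005649"
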
torchvision/prototype/datasets/_builtin/gtsrb.py (deleted, 100644 → 0)

import pathlib
from typing import Any, Dict, List, Optional, Tuple, Union

from torchdata.datapipes.iter import CSVDictParser, Demultiplexer, Filter, IterDataPipe, Mapper, Zipper
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
    hint_sharding,
    hint_shuffling,
    INFINITE_BUFFER_SIZE,
    path_comparator,
)
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import BoundingBoxes

from .._api import register_dataset, register_info

NAME = "gtsrb"


@register_info(NAME)
def _info() -> Dict[str, Any]:
    return dict(
        categories=[f"{label:05d}" for label in range(43)],
    )


@register_dataset(NAME)
class GTSRB(Dataset):
    """GTSRB Dataset
    homepage="https://benchmark.ini.rub.de"
    """

    def __init__(
        self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False
    ) -> None:
        self._split = self._verify_str_arg(split, "split", {"train", "test"})
        self._categories = _info()["categories"]
        super().__init__(root, skip_integrity_check=skip_integrity_check)

    _URL_ROOT = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/"
    _URLS = {
        "train": f"{_URL_ROOT}GTSRB-Training_fixed.zip",
        "test": f"{_URL_ROOT}GTSRB_Final_Test_Images.zip",
        "test_ground_truth": f"{_URL_ROOT}GTSRB_Final_Test_GT.zip",
    }
    _CHECKSUMS = {
        "train": "df4144942083645bd60b594de348aa6930126c3e0e5de09e39611630abf8455a",
        "test": "48ba6fab7e877eb64eaf8de99035b0aaecfbc279bee23e35deca4ac1d0a837fa",
        "test_ground_truth": "f94e5a7614d75845c74c04ddb26b8796b9e483f43541dd95dd5b726504e16d6d",
    }

    def _resources(self) -> List[OnlineResource]:
        rsrcs: List[OnlineResource] = [HttpResource(self._URLS[self._split], sha256=self._CHECKSUMS[self._split])]

        if self._split == "test":
            rsrcs.append(
                HttpResource(
                    self._URLS["test_ground_truth"],
                    sha256=self._CHECKSUMS["test_ground_truth"],
                )
            )

        return rsrcs

    def _classify_train_archive(self, data: Tuple[str, Any]) -> Optional[int]:
        path = pathlib.Path(data[0])
        if path.suffix == ".ppm":
            return 0
        elif path.suffix == ".csv":
            return 1
        else:
            return None

    def _prepare_sample(self, data: Tuple[Tuple[str, Any], Dict[str, Any]]) -> Dict[str, Any]:
        (path, buffer), csv_info = data
        label = int(csv_info["ClassId"])

        bounding_boxes = BoundingBoxes(
            [int(csv_info[k]) for k in ("Roi.X1", "Roi.Y1", "Roi.X2", "Roi.Y2")],
            format="xyxy",
            spatial_size=(int(csv_info["Height"]), int(csv_info["Width"])),
        )

        return {
            "path": path,
            "image": EncodedImage.from_file(buffer),
            "label": Label(label, categories=self._categories),
            "bounding_boxes": bounding_boxes,
        }

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        if self._split == "train":
            images_dp, ann_dp = Demultiplexer(
                resource_dps[0], 2, self._classify_train_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
            )
        else:
            images_dp, ann_dp = resource_dps
            images_dp = Filter(images_dp, path_comparator("suffix", ".ppm"))

        # The order of the image files in the .zip archives perfectly matches the order of the entries in the
        # (possibly concatenated) .csv files. So we're able to use Zipper here instead of an IterKeyZipper.
        ann_dp = CSVDictParser(ann_dp, delimiter=";")
        dp = Zipper(images_dp, ann_dp)

        dp = hint_shuffling(dp)
        dp = hint_sharding(dp)

        return Mapper(dp, self._prepare_sample)

    def __len__(self) -> int:
        return 26_640 if self._split == "train" else 12_630

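Because image files and csv rows arrive in matching order, the datapipe above can pair them positionally with Zipper instead of doing a key-based join. A sketch with hypothetical entries:

from torchdata.datapipes.iter import IterableWrapper, Zipper

images = IterableWrapper(["00000.ppm", "00001.ppm"])          # hypothetical file order
rows = IterableWrapper([{"ClassId": "1"}, {"ClassId": "2"}])  # matching csv row order
for pair in Zipper(images, rows):
    print(pair)  # ('00000.ppm', {'ClassId': '1'}) then ('00001.ppm', {'ClassId': '2'})
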
torchvision/prototype/datasets/_builtin/imagenet.categories (deleted, 100644 → 0)

tench,n01440764
goldfish,n01443537
great white shark,n01484850
tiger shark,n01491361
hammerhead,n01494475
electric ray,n01496331
stingray,n01498041
cock,n01514668
hen,n01514859
ostrich,n01518878
brambling,n01530575
goldfinch,n01531178
house finch,n01532829
junco,n01534433
indigo bunting,n01537544
robin,n01558993
bulbul,n01560419
jay,n01580077
magpie,n01582220
chickadee,n01592084
water ouzel,n01601694
kite,n01608432
bald eagle,n01614925
vulture,n01616318
great grey owl,n01622779
European fire salamander,n01629819
common newt,n01630670
eft,n01631663
spotted salamander,n01632458
axolotl,n01632777
bullfrog,n01641577
tree frog,n01644373
tailed frog,n01644900
loggerhead,n01664065
leatherback turtle,n01665541
mud turtle,n01667114
terrapin,n01667778
box turtle,n01669191
banded gecko,n01675722
common iguana,n01677366
American chameleon,n01682714
whiptail,n01685808
agama,n01687978
frilled lizard,n01688243
alligator lizard,n01689811
Gila monster,n01692333
green lizard,n01693334
African chameleon,n01694178
Komodo dragon,n01695060
African crocodile,n01697457
American alligator,n01698640
triceratops,n01704323
thunder snake,n01728572
ringneck snake,n01728920
hognose snake,n01729322
green snake,n01729977
king snake,n01734418
garter snake,n01735189
water snake,n01737021
vine snake,n01739381
night snake,n01740131
boa constrictor,n01742172
rock python,n01744401
Indian cobra,n01748264
green mamba,n01749939
sea snake,n01751748
horned viper,n01753488
diamondback,n01755581
sidewinder,n01756291
trilobite,n01768244
harvestman,n01770081
scorpion,n01770393
black and gold garden spider,n01773157
barn spider,n01773549
garden spider,n01773797
black widow,n01774384
tarantula,n01774750
wolf spider,n01775062
tick,n01776313
centipede,n01784675
black grouse,n01795545
ptarmigan,n01796340
ruffed grouse,n01797886
prairie chicken,n01798484
peacock,n01806143
quail,n01806567
partridge,n01807496
African grey,n01817953
macaw,n01818515
sulphur-crested cockatoo,n01819313
lorikeet,n01820546
coucal,n01824575
bee eater,n01828970
hornbill,n01829413
hummingbird,n01833805
jacamar,n01843065
toucan,n01843383
drake,n01847000
red-breasted merganser,n01855032
goose,n01855672
black swan,n01860187
tusker,n01871265
echidna,n01872401
platypus,n01873310
wallaby,n01877812
koala,n01882714
wombat,n01883070
jellyfish,n01910747
sea anemone,n01914609
brain coral,n01917289
flatworm,n01924916
nematode,n01930112
conch,n01943899
snail,n01944390
slug,n01945685
sea slug,n01950731
chiton,n01955084
chambered nautilus,n01968897
Dungeness crab,n01978287
rock crab,n01978455
fiddler crab,n01980166
king crab,n01981276
American lobster,n01983481
spiny lobster,n01984695
crayfish,n01985128
hermit crab,n01986214
isopod,n01990800
white stork,n02002556
black stork,n02002724
spoonbill,n02006656
flamingo,n02007558
little blue heron,n02009229
American egret,n02009912
bittern,n02011460
crane,n02012849
limpkin,n02013706
European gallinule,n02017213
American coot,n02018207
bustard,n02018795
ruddy turnstone,n02025239
red-backed sandpiper,n02027492
redshank,n02028035
dowitcher,n02033041
oystercatcher,n02037110
pelican,n02051845
king penguin,n02056570
albatross,n02058221
grey whale,n02066245
killer whale,n02071294
dugong,n02074367
sea lion,n02077923
Chihuahua,n02085620
Japanese spaniel,n02085782
Maltese dog,n02085936
Pekinese,n02086079
Shih-Tzu,n02086240
Blenheim spaniel,n02086646
papillon,n02086910
toy terrier,n02087046
Rhodesian ridgeback,n02087394
Afghan hound,n02088094
basset,n02088238
beagle,n02088364
bloodhound,n02088466
bluetick,n02088632
black-and-tan coonhound,n02089078
Walker hound,n02089867
English foxhound,n02089973
redbone,n02090379
borzoi,n02090622
Irish wolfhound,n02090721
Italian greyhound,n02091032
whippet,n02091134
Ibizan hound,n02091244
Norwegian elkhound,n02091467
otterhound,n02091635
Saluki,n02091831
Scottish deerhound,n02092002
Weimaraner,n02092339
Staffordshire bullterrier,n02093256
American Staffordshire terrier,n02093428
Bedlington terrier,n02093647
Border terrier,n02093754
Kerry blue terrier,n02093859
Irish terrier,n02093991
Norfolk terrier,n02094114
Norwich terrier,n02094258
Yorkshire terrier,n02094433
wire-haired fox terrier,n02095314
Lakeland terrier,n02095570
Sealyham terrier,n02095889
Airedale,n02096051
cairn,n02096177
Australian terrier,n02096294
Dandie Dinmont,n02096437
Boston bull,n02096585
miniature schnauzer,n02097047
giant schnauzer,n02097130
standard schnauzer,n02097209
Scotch terrier,n02097298
Tibetan terrier,n02097474
silky terrier,n02097658
soft-coated wheaten terrier,n02098105
West Highland white terrier,n02098286
Lhasa,n02098413
flat-coated retriever,n02099267
curly-coated retriever,n02099429
golden retriever,n02099601
Labrador retriever,n02099712
Chesapeake Bay retriever,n02099849
German short-haired pointer,n02100236
vizsla,n02100583
English setter,n02100735
Irish setter,n02100877
Gordon setter,n02101006
Brittany spaniel,n02101388
clumber,n02101556
English springer,n02102040
Welsh springer spaniel,n02102177
cocker spaniel,n02102318
Sussex spaniel,n02102480
Irish water spaniel,n02102973
kuvasz,n02104029
schipperke,n02104365
groenendael,n02105056
malinois,n02105162
briard,n02105251
kelpie,n02105412
komondor,n02105505
Old English sheepdog,n02105641
Shetland sheepdog,n02105855
collie,n02106030
Border collie,n02106166
Bouvier des Flandres,n02106382
Rottweiler,n02106550
German shepherd,n02106662
Doberman,n02107142
miniature pinscher,n02107312
Greater Swiss Mountain dog,n02107574
Bernese mountain dog,n02107683
Appenzeller,n02107908
EntleBucher,n02108000
boxer,n02108089
bull mastiff,n02108422
Tibetan mastiff,n02108551
French bulldog,n02108915
Great Dane,n02109047
Saint Bernard,n02109525
Eskimo dog,n02109961
malamute,n02110063
Siberian husky,n02110185
dalmatian,n02110341
affenpinscher,n02110627
basenji,n02110806
pug,n02110958
Leonberg,n02111129
Newfoundland,n02111277
Great Pyrenees,n02111500
Samoyed,n02111889
Pomeranian,n02112018
chow,n02112137
keeshond,n02112350
Brabancon griffon,n02112706
Pembroke,n02113023
Cardigan,n02113186
toy poodle,n02113624
miniature poodle,n02113712
standard poodle,n02113799
Mexican hairless,n02113978
timber wolf,n02114367
white wolf,n02114548
red wolf,n02114712
coyote,n02114855
dingo,n02115641
dhole,n02115913
African hunting dog,n02116738
hyena,n02117135
red fox,n02119022
kit fox,n02119789
Arctic fox,n02120079
grey fox,n02120505
tabby,n02123045
tiger cat,n02123159
Persian cat,n02123394
Siamese cat,n02123597
Egyptian cat,n02124075
cougar,n02125311
lynx,n02127052
leopard,n02128385
snow leopard,n02128757
jaguar,n02128925
lion,n02129165
tiger,n02129604
cheetah,n02130308
brown bear,n02132136
American black bear,n02133161
ice bear,n02134084
sloth bear,n02134418
mongoose,n02137549
meerkat,n02138441
tiger beetle,n02165105
ladybug,n02165456
ground beetle,n02167151
long-horned beetle,n02168699
leaf beetle,n02169497
dung beetle,n02172182
rhinoceros beetle,n02174001
weevil,n02177972
fly,n02190166
bee,n02206856
ant,n02219486
grasshopper,n02226429
cricket,n02229544
walking stick,n02231487
cockroach,n02233338
mantis,n02236044
cicada,n02256656
leafhopper,n02259212
lacewing,n02264363
dragonfly,n02268443
damselfly,n02268853
admiral,n02276258
ringlet,n02277742
monarch,n02279972
cabbage butterfly,n02280649
sulphur butterfly,n02281406
lycaenid,n02281787
starfish,n02317335
sea urchin,n02319095
sea cucumber,n02321529
wood rabbit,n02325366
hare,n02326432
Angora,n02328150
hamster,n02342885
porcupine,n02346627
fox squirrel,n02356798
marmot,n02361337
beaver,n02363005
guinea pig,n02364673
sorrel,n02389026
zebra,n02391049
hog,n02395406
wild boar,n02396427
warthog,n02397096
hippopotamus,n02398521
ox,n02403003
water buffalo,n02408429
bison,n02410509
ram,n02412080
bighorn,n02415577
ibex,n02417914
hartebeest,n02422106
impala,n02422699
gazelle,n02423022
Arabian camel,n02437312
llama,n02437616
weasel,n02441942
mink,n02442845
polecat,n02443114
black-footed ferret,n02443484
otter,n02444819
skunk,n02445715
badger,n02447366
armadillo,n02454379
three-toed sloth,n02457408
orangutan,n02480495
gorilla,n02480855
chimpanzee,n02481823
gibbon,n02483362
siamang,n02483708
guenon,n02484975
patas,n02486261
baboon,n02486410
macaque,n02487347
langur,n02488291
colobus,n02488702
proboscis monkey,n02489166
marmoset,n02490219
capuchin,n02492035
howler monkey,n02492660
titi,n02493509
spider monkey,n02493793
squirrel monkey,n02494079
Madagascar cat,n02497673
indri,n02500267
Indian elephant,n02504013
African elephant,n02504458
lesser panda,n02509815
giant panda,n02510455
barracouta,n02514041
eel,n02526121
coho,n02536864
rock beauty,n02606052
anemone fish,n02607072
sturgeon,n02640242
gar,n02641379
lionfish,n02643566
puffer,n02655020
abacus,n02666196
abaya,n02667093
academic gown,n02669723
accordion,n02672831
acoustic guitar,n02676566
aircraft carrier,n02687172
airliner,n02690373
airship,n02692877
altar,n02699494
ambulance,n02701002
amphibian,n02704792
analog clock,n02708093
apiary,n02727426
apron,n02730930
ashcan,n02747177
assault rifle,n02749479
backpack,n02769748
bakery,n02776631
balance beam,n02777292
balloon,n02782093
ballpoint,n02783161
Band Aid,n02786058
banjo,n02787622
bannister,n02788148
barbell,n02790996
barber chair,n02791124
barbershop,n02791270
barn,n02793495
barometer,n02794156
barrel,n02795169
barrow,n02797295
baseball,n02799071
basketball,n02802426
bassinet,n02804414
bassoon,n02804610
bathing cap,n02807133
bath towel,n02808304
bathtub,n02808440
beach wagon,n02814533
beacon,n02814860
beaker,n02815834
bearskin,n02817516
beer bottle,n02823428
beer glass,n02823750
bell cote,n02825657
bib,n02834397
bicycle-built-for-two,n02835271
bikini,n02837789
binder,n02840245
binoculars,n02841315
birdhouse,n02843684
boathouse,n02859443
bobsled,n02860847
bolo tie,n02865351
bonnet,n02869837
bookcase,n02870880
bookshop,n02871525
bottlecap,n02877765
bow,n02879718
bow tie,n02883205
brass,n02892201
brassiere,n02892767
breakwater,n02894605
breastplate,n02895154
broom,n02906734
bucket,n02909870
buckle,n02910353
bulletproof vest,n02916936
bullet train,n02917067
butcher shop,n02927161
cab,n02930766
caldron,n02939185
candle,n02948072
cannon,n02950826
canoe,n02951358
can opener,n02951585
cardigan,n02963159
car mirror,n02965783
carousel,n02966193
carpenter's kit,n02966687
carton,n02971356
car wheel,n02974003
cash machine,n02977058
cassette,n02978881
cassette player,n02979186
castle,n02980441
catamaran,n02981792
CD player,n02988304
cello,n02992211
cellular telephone,n02992529
chain,n02999410
chainlink fence,n03000134
chain mail,n03000247
chain saw,n03000684
chest,n03014705
chiffonier,n03016953
chime,n03017168
china cabinet,n03018349
Christmas stocking,n03026506
church,n03028079
cinema,n03032252
cleaver,n03041632
cliff dwelling,n03042490
cloak,n03045698
clog,n03047690
cocktail shaker,n03062245
coffee mug,n03063599
coffeepot,n03063689
coil,n03065424
combination lock,n03075370
computer keyboard,n03085013
confectionery,n03089624
container ship,n03095699
convertible,n03100240
corkscrew,n03109150
cornet,n03110669
cowboy boot,n03124043
cowboy hat,n03124170
cradle,n03125729
construction crane,n03126707
crash helmet,n03127747
crate,n03127925
crib,n03131574
Crock Pot,n03133878
croquet ball,n03134739
crutch,n03141823
cuirass,n03146219
dam,n03160309
desk,n03179701
desktop computer,n03180011
dial telephone,n03187595
diaper,n03188531
digital clock,n03196217
digital watch,n03197337
dining table,n03201208
dishrag,n03207743
dishwasher,n03207941
disk brake,n03208938
dock,n03216828
dogsled,n03218198
dome,n03220513
doormat,n03223299
drilling platform,n03240683
drum,n03249569
drumstick,n03250847
dumbbell,n03255030
Dutch oven,n03259280
electric fan,n03271574
electric guitar,n03272010
electric locomotive,n03272562
entertainment center,n03290653
envelope,n03291819
espresso maker,n03297495
face powder,n03314780
feather boa,n03325584
file,n03337140
fireboat,n03344393
fire engine,n03345487
fire screen,n03347037
flagpole,n03355925
flute,n03372029
folding chair,n03376595
football helmet,n03379051
forklift,n03384352
fountain,n03388043
fountain pen,n03388183
four-poster,n03388549
freight car,n03393912
French horn,n03394916
frying pan,n03400231
fur coat,n03404251
garbage truck,n03417042
gasmask,n03424325
gas pump,n03425413
goblet,n03443371
go-kart,n03444034
golf ball,n03445777
golfcart,n03445924
gondola,n03447447
gong,n03447721
gown,n03450230
grand piano,n03452741
greenhouse,n03457902
grille,n03459775
grocery store,n03461385
guillotine,n03467068
hair slide,n03476684
hair spray,n03476991
half track,n03478589
hammer,n03481172
hamper,n03482405
hand blower,n03483316
hand-held computer,n03485407
handkerchief,n03485794
hard disc,n03492542
harmonica,n03494278
harp,n03495258
harvester,n03496892
hatchet,n03498962
holster,n03527444
home theater,n03529860
honeycomb,n03530642
hook,n03532672
hoopskirt,n03534580
horizontal bar,n03535780
horse cart,n03538406
hourglass,n03544143
iPod,n03584254
iron,n03584829
jack-o'-lantern,n03590841
jean,n03594734
jeep,n03594945
jersey,n03595614
jigsaw puzzle,n03598930
jinrikisha,n03599486
joystick,n03602883
kimono,n03617480
knee pad,n03623198
knot,n03627232
lab coat,n03630383
ladle,n03633091
lampshade,n03637318
laptop,n03642806
lawn mower,n03649909
lens cap,n03657121
letter opener,n03658185
library,n03661043
lifeboat,n03662601
lighter,n03666591
limousine,n03670208
liner,n03673027
lipstick,n03676483
Loafer,n03680355
lotion,n03690938
loudspeaker,n03691459
loupe,n03692522
lumbermill,n03697007
magnetic compass,n03706229
mailbag,n03709823
mailbox,n03710193
maillot,n03710637
tank suit,n03710721
manhole cover,n03717622
maraca,n03720891
marimba,n03721384
mask,n03724870
matchstick,n03729826
maypole,n03733131
maze,n03733281
measuring cup,n03733805
medicine chest,n03742115
megalith,n03743016
microphone,n03759954
microwave,n03761084
military uniform,n03763968
milk can,n03764736
minibus,n03769881
miniskirt,n03770439
minivan,n03770679
missile,n03773504
mitten,n03775071
mixing bowl,n03775546
mobile home,n03776460
Model T,n03777568
modem,n03777754
monastery,n03781244
monitor,n03782006
moped,n03785016
mortar,n03786901
mortarboard,n03787032
mosque,n03788195
mosquito net,n03788365
motor scooter,n03791053
mountain bike,n03792782
mountain tent,n03792972
mouse,n03793489
mousetrap,n03794056
moving van,n03796401
muzzle,n03803284
nail,n03804744
neck brace,n03814639
necklace,n03814906
nipple,n03825788
notebook,n03832673
obelisk,n03837869
oboe,n03838899
ocarina,n03840681
odometer,n03841143
oil filter,n03843555
organ,n03854065
oscilloscope,n03857828
overskirt,n03866082
oxcart,n03868242
oxygen mask,n03868863
packet,n03871628
paddle,n03873416
paddlewheel,n03874293
padlock,n03874599
paintbrush,n03876231
pajama,n03877472
palace,n03877845
panpipe,n03884397
paper towel,n03887697
parachute,n03888257
parallel bars,n03888605
park bench,n03891251
parking meter,n03891332
passenger car,n03895866
patio,n03899768
pay-phone,n03902125
pedestal,n03903868
pencil box,n03908618
pencil sharpener,n03908714
perfume,n03916031
Petri dish,n03920288
photocopier,n03924679
pick,n03929660
pickelhaube,n03929855
picket fence,n03930313
pickup,n03930630
pier,n03933933
piggy bank,n03935335
pill bottle,n03937543
pillow,n03938244
ping-pong ball,n03942813
pinwheel,n03944341
pirate,n03947888
pitcher,n03950228
plane,n03954731
planetarium,n03956157
plastic bag,n03958227
plate rack,n03961711
plow,n03967562
plunger,n03970156
Polaroid camera,n03976467
pole,n03976657
police van,n03977966
poncho,n03980874
pool table,n03982430
pop bottle,n03983396
pot,n03991062
potter's wheel,n03992509
power drill,n03995372
prayer rug,n03998194
printer,n04004767
prison,n04005630
projectile,n04008634
projector,n04009552
puck,n04019541
punching bag,n04023962
purse,n04026417
quill,n04033901
quilt,n04033995
racer,n04037443
racket,n04039381
radiator,n04040759
radio,n04041544
radio telescope,n04044716
rain barrel,n04049303
recreational vehicle,n04065272
reel,n04067472
reflex camera,n04069434
refrigerator,n04070727
remote control,n04074963
restaurant,n04081281
revolver,n04086273
rifle,n04090263
rocking chair,n04099969
rotisserie,n04111531
rubber eraser,n04116512
rugby ball,n04118538
rule,n04118776
running shoe,n04120489
safe,n04125021
safety pin,n04127249
saltshaker,n04131690
sandal,n04133789
sarong,n04136333
sax,n04141076
scabbard,n04141327
scale,n04141975
school bus,n04146614
schooner,n04147183
scoreboard,n04149813
screen,n04152593
screw,n04153751
screwdriver,n04154565
seat belt,n04162706
sewing machine,n04179913
shield,n04192698
shoe shop,n04200800
shoji,n04201297
shopping basket,n04204238
shopping cart,n04204347
shovel,n04208210
shower cap,n04209133
shower curtain,n04209239
ski,n04228054
ski mask,n04229816
sleeping bag,n04235860
slide rule,n04238763
sliding door,n04239074
slot,n04243546
snorkel,n04251144
snowmobile,n04252077
snowplow,n04252225
soap dispenser,n04254120
soccer ball,n04254680
sock,n04254777
solar dish,n04258138
sombrero,n04259630
soup bowl,n04263257
space bar,n04264628
space heater,n04265275
space shuttle,n04266014
spatula,n04270147
speedboat,n04273569
spider web,n04275548
spindle,n04277352
sports car,n04285008
spotlight,n04286575
stage,n04296562
steam locomotive,n04310018
steel arch bridge,n04311004
steel drum,n04311174
stethoscope,n04317175
stole,n04325704
stone wall,n04326547
stopwatch,n04328186
stove,n04330267
strainer,n04332243
streetcar,n04335435
stretcher,n04336792
studio couch,n04344873
stupa,n04346328
submarine,n04347754
suit,n04350905
sundial,n04355338
sunglass,n04355933
sunglasses,n04356056
sunscreen,n04357314
suspension bridge,n04366367
swab,n04367480
sweatshirt,n04370456
swimming trunks,n04371430
swing,n04371774
switch,n04372370
syringe,n04376876
table lamp,n04380533
tank,n04389033
tape player,n04392985
teapot,n04398044
teddy,n04399382
television,n04404412
tennis ball,n04409515
thatch,n04417672
theater curtain,n04418357
thimble,n04423845
thresher,n04428191
throne,n04429376
tile roof,n04435653
toaster,n04442312
tobacco shop,n04443257
toilet seat,n04447861
torch,n04456115
totem pole,n04458633
tow truck,n04461696
toyshop,n04462240
tractor,n04465501
trailer truck,n04467665
tray,n04476259
trench coat,n04479046
tricycle,n04482393
trimaran,n04483307
tripod,n04485082
triumphal arch,n04486054
trolleybus,n04487081
trombone,n04487394
tub,n04493381
turnstile,n04501370
typewriter keyboard,n04505470
umbrella,n04507155
unicycle,n04509417
upright,n04515003
vacuum,n04517823
vase,n04522168
vault,n04523525
velvet,n04525038
vending machine,n04525305
vestment,n04532106
viaduct,n04532670
violin,n04536866
volleyball,n04540053
waffle iron,n04542943
wall clock,n04548280
wallet,n04548362
wardrobe,n04550184
warplane,n04552348
washbasin,n04553703
washer,n04554684
water bottle,n04557648
water jug,n04560804
water tower,n04562935
whiskey jug,n04579145
whistle,n04579432
wig,n04584207
window screen,n04589890
window shade,n04590129
Windsor tie,n04591157
wine bottle,n04591713
wing,n04592741
wok,n04596742
wooden spoon,n04597913
wool,n04599235
worm fence,n04604644
wreck,n04606251
yawl,n04612504
yurt,n04613696
web site,n06359193
comic book,n06596364
crossword puzzle,n06785654
street sign,n06794110
traffic light,n06874185
book jacket,n07248320
menu,n07565083
plate,n07579787
guacamole,n07583066
consomme,n07584110
hot pot,n07590611
trifle,n07613480
ice cream,n07614500
ice lolly,n07615774
French loaf,n07684084
bagel,n07693725
pretzel,n07695742
cheeseburger,n07697313
hotdog,n07697537
mashed potato,n07711569
head cabbage,n07714571
broccoli,n07714990
cauliflower,n07715103
zucchini,n07716358
spaghetti squash,n07716906
acorn squash,n07717410
butternut squash,n07717556
cucumber,n07718472
artichoke,n07718747
bell pepper,n07720875
cardoon,n07730033
mushroom,n07734744
Granny Smith,n07742313
strawberry,n07745940
orange,n07747607
lemon,n07749582
fig,n07753113
pineapple,n07753275
banana,n07753592
jackfruit,n07754684
custard apple,n07760859
pomegranate,n07768694
hay,n07802026
carbonara,n07831146
chocolate sauce,n07836838
dough,n07860988
meat loaf,n07871810
pizza,n07873807
potpie,n07875152
burrito,n07880968
red wine,n07892512
espresso,n07920052
cup,n07930864
eggnog,n07932039
alp,n09193705
bubble,n09229709
cliff,n09246464
coral reef,n09256479
geyser,n09288635
lakeside,n09332890
promontory,n09399592
sandbar,n09421951
seashore,n09428293
valley,n09468604
volcano,n09472597
ballplayer,n09835506
groom,n10148035
scuba diver,n10565667
rapeseed,n11879895
daisy,n11939491
yellow lady's slipper,n12057211
corn,n12144580
acorn,n12267677
hip,n12620546
buckeye,n12768682
coral fungus,n12985857
agaric,n12998815
gyromitra,n13037406
stinkhorn,n13040303
earthstar,n13044778
hen-of-the-woods,n13052670
bolete,n13054560
ear,n13133613
toilet tissue,n15075141
torchvision/prototype/datasets/_builtin/imagenet.py deleted 100644 → 0
import enum
import pathlib
import re
from typing import Any, BinaryIO, cast, Dict, Iterator, List, Match, Optional, Tuple, Union

from torchdata.datapipes.iter import (
    Demultiplexer,
    Enumerator,
    Filter,
    IterDataPipe,
    IterKeyZipper,
    LineReader,
    Mapper,
    TarArchiveLoader,
)
from torchdata.datapipes.map import IterToMapConverter
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, ManualDownloadResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
    getitem,
    hint_sharding,
    hint_shuffling,
    INFINITE_BUFFER_SIZE,
    path_accessor,
    read_categories_file,
    read_mat,
)
from torchvision.prototype.tv_tensors import Label

from .._api import register_dataset, register_info

NAME = "imagenet"


@register_info(NAME)
def _info() -> Dict[str, Any]:
    categories, wnids = zip(*read_categories_file(NAME))
    return dict(categories=categories, wnids=wnids)


class ImageNetResource(ManualDownloadResource):
    def __init__(self, **kwargs: Any) -> None:
        super().__init__("Register on https://image-net.org/ and follow the instructions there.", **kwargs)


class ImageNetDemux(enum.IntEnum):
    META = 0
    LABEL = 1


class CategoryAndWordNetIDExtractor(IterDataPipe):
    # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
    # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
    _WNID_MAP = {
        "n03126707": "construction crane",
        "n03710721": "tank suit",
    }

    def __init__(self, datapipe: IterDataPipe[Tuple[str, BinaryIO]]) -> None:
        self.datapipe = datapipe

    def __iter__(self) -> Iterator[Tuple[str, str]]:
        for _, stream in self.datapipe:
            synsets = read_mat(stream, squeeze_me=True)["synsets"]
            for _, wnid, category, _, num_children, *_ in synsets:
                if num_children > 0:
                    # we are looking at a superclass that has no direct instance
                    continue
                yield self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid


@register_dataset(NAME)
class ImageNet(Dataset):
    """
    - **homepage**: https://www.image-net.org/
    """

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        *,
        split: str = "train",
        skip_integrity_check: bool = False,
    ) -> None:
        self._split = self._verify_str_arg(split, "split", {"train", "val", "test"})

        info = _info()
        categories, wnids = info["categories"], info["wnids"]
        self._categories = categories
        self._wnids = wnids
        self._wnid_to_category = dict(zip(wnids, categories))

        super().__init__(root, skip_integrity_check=skip_integrity_check)

    _IMAGES_CHECKSUMS = {
        "train": "b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb",
        "val": "c7e06a6c0baccf06d8dbeb6577d71efff84673a5dbdd50633ab44f8ea0456ae0",
        "test_v10102019": "9cf7f8249639510f17d3d8a0deb47cd22a435886ba8e29e2b3223e65a4079eb4",
    }

    def _resources(self) -> List[OnlineResource]:
        name = "test_v10102019" if self._split == "test" else self._split
        images = ImageNetResource(
            file_name=f"ILSVRC2012_img_{name}.tar",
            sha256=self._IMAGES_CHECKSUMS[name],
        )
        resources: List[OnlineResource] = [images]

        if self._split == "val":
            devkit = ImageNetResource(
                file_name="ILSVRC2012_devkit_t12.tar.gz",
                sha256="b59243268c0d266621fd587d2018f69e906fb22875aca0e295b48cafaa927953",
            )
            resources.append(devkit)

        return resources

    _TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG")

    def _prepare_train_data(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]:
        path = pathlib.Path(data[0])
        wnid = cast(Match[str], self._TRAIN_IMAGE_NAME_PATTERN.match(path.name))["wnid"]
        label = Label.from_category(self._wnid_to_category[wnid], categories=self._categories)
        return (label, wnid), data

    def _prepare_test_data(self, data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[str, BinaryIO]]:
        return None, data

    def _classifiy_devkit(self, data: Tuple[str, BinaryIO]) -> Optional[int]:
        return {
            "meta.mat": ImageNetDemux.META,
            "ILSVRC2012_validation_ground_truth.txt": ImageNetDemux.LABEL,
        }.get(pathlib.Path(data[0]).name)

    _VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")

    def _val_test_image_key(self, path: pathlib.Path) -> int:
        return int(self._VAL_TEST_IMAGE_NAME_PATTERN.match(path.name)["id"])  # type: ignore[index]

    def _prepare_val_data(
        self, data: Tuple[Tuple[int, str], Tuple[str, BinaryIO]]
    ) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]:
        label_data, image_data = data
        _, wnid = label_data
        label = Label.from_category(self._wnid_to_category[wnid], categories=self._categories)
        return (label, wnid), image_data

    def _prepare_sample(
        self,
        data: Tuple[Optional[Tuple[Label, str]], Tuple[str, BinaryIO]],
    ) -> Dict[str, Any]:
        label_data, (path, buffer) = data

        return dict(
            dict(zip(("label", "wnid"), label_data if label_data else (None, None))),
            path=path,
            image=EncodedImage.from_file(buffer),
        )

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        if self._split in {"train", "test"}:
            dp = resource_dps[0]

            # the train archive is a tar of tars
            if self._split == "train":
                dp = TarArchiveLoader(dp)

            dp = hint_shuffling(dp)
            dp = hint_sharding(dp)
            dp = Mapper(dp, self._prepare_train_data if self._split == "train" else self._prepare_test_data)
        else:  # config.split == "val":
            images_dp, devkit_dp = resource_dps

            meta_dp, label_dp = Demultiplexer(
                devkit_dp, 2, self._classifiy_devkit, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
            )

            # We cannot use self._wnids here, since we use a different order than the dataset
            meta_dp = CategoryAndWordNetIDExtractor(meta_dp)
            wnid_dp = Mapper(meta_dp, getitem(1))
            wnid_dp = Enumerator(wnid_dp, 1)
            wnid_map = IterToMapConverter(wnid_dp)

            label_dp = LineReader(label_dp, decode=True, return_path=False)
            label_dp = Mapper(label_dp, int)
            label_dp = Mapper(label_dp, wnid_map.__getitem__)
            label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1)
            label_dp = hint_shuffling(label_dp)
            label_dp = hint_sharding(label_dp)

            dp = IterKeyZipper(
                label_dp,
                images_dp,
                key_fn=getitem(0),
                ref_key_fn=path_accessor(self._val_test_image_key),
                buffer_size=INFINITE_BUFFER_SIZE,
            )
            dp = Mapper(dp, self._prepare_val_data)

        return Mapper(dp, self._prepare_sample)

    def __len__(self) -> int:
        return {
            "train": 1_281_167,
            "val": 50_000,
            "test": 100_000,
        }[self._split]

    def _filter_meta(self, data: Tuple[str, Any]) -> bool:
        return self._classifiy_devkit(data) == ImageNetDemux.META

    def _generate_categories(self) -> List[Tuple[str, ...]]:
        self._split = "val"
        resources = self._resources()

        devkit_dp = resources[1].load(self._root)
        meta_dp = Filter(devkit_dp, self._filter_meta)
        meta_dp = CategoryAndWordNetIDExtractor(meta_dp)

        categories_and_wnids = cast(List[Tuple[str, ...]], list(meta_dp))
        categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1])
        return categories_and_wnids
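For reference, the imagenet.categories file listed above is a plain two-column CSV of (category, wnid) rows; _info() unzips it into parallel tuples and ImageNet inverts it into a wnid-to-category lookup. A minimal standalone sketch of that lookup, independent of the deleted read_categories_file helper (the file path in the usage comment is hypothetical):

import csv
import pathlib
from typing import Dict

def load_wnid_to_category(path: pathlib.Path) -> Dict[str, str]:
    # Each row reads "<category>,<wnid>", e.g. "construction crane,n03126707".
    with open(path, newline="") as file:
        categories, wnids = zip(*csv.reader(file))
    return dict(zip(wnids, categories))

# Hypothetical usage:
# load_wnid_to_category(pathlib.Path("imagenet.categories"))["n03126707"]
# returns "construction crane", matching the _WNID_MAP disambiguation above.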
torchvision/prototype/datasets/_builtin/mnist.py deleted 100644 → 0
import abc
import functools
import operator
import pathlib
import string
from typing import Any, BinaryIO, cast, Dict, Iterator, List, Optional, Sequence, Tuple, Union

import torch
from torchdata.datapipes.iter import Decompressor, Demultiplexer, IterDataPipe, Mapper, Zipper
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, INFINITE_BUFFER_SIZE
from torchvision.prototype.tv_tensors import Label
from torchvision.prototype.utils._internal import fromfile
from torchvision.tv_tensors import Image

from .._api import register_dataset, register_info

prod = functools.partial(functools.reduce, operator.mul)


class MNISTFileReader(IterDataPipe[torch.Tensor]):
    _DTYPE_MAP = {
        8: torch.uint8,
        9: torch.int8,
        11: torch.int16,
        12: torch.int32,
        13: torch.float32,
        14: torch.float64,
    }

    def __init__(
        self, datapipe: IterDataPipe[Tuple[Any, BinaryIO]], *, start: Optional[int], stop: Optional[int]
    ) -> None:
        self.datapipe = datapipe
        self.start = start
        self.stop = stop

    def __iter__(self) -> Iterator[torch.Tensor]:
        for _, file in self.datapipe:
            try:
                read = functools.partial(fromfile, file, byte_order="big")

                magic = int(read(dtype=torch.int32, count=1))
                dtype = self._DTYPE_MAP[magic // 256]
                ndim = magic % 256 - 1

                num_samples = int(read(dtype=torch.int32, count=1))
                shape = cast(List[int], read(dtype=torch.int32, count=ndim).tolist()) if ndim else []
                count = prod(shape) if shape else 1

                start = self.start or 0
                stop = min(self.stop, num_samples) if self.stop else num_samples

                if start:
                    num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
                    file.seek(num_bytes_per_value * count * start, 1)

                for _ in range(stop - start):
                    yield read(dtype=dtype, count=count).reshape(shape)
            finally:
                file.close()


class _MNISTBase(Dataset):
    _URL_BASE: Union[str, Sequence[str]]

    @abc.abstractmethod
    def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]:
        pass

    def _resources(self) -> List[OnlineResource]:
        (images_file, images_sha256), (
            labels_file,
            labels_sha256,
        ) = self._files_and_checksums()

        url_bases = self._URL_BASE
        if isinstance(url_bases, str):
            url_bases = (url_bases,)

        images_urls = [f"{url_base}/{images_file}" for url_base in url_bases]
        images = HttpResource(images_urls[0], sha256=images_sha256, mirrors=images_urls[1:])

        labels_urls = [f"{url_base}/{labels_file}" for url_base in url_bases]
        labels = HttpResource(labels_urls[0], sha256=labels_sha256, mirrors=labels_urls[1:])

        return [images, labels]

    def start_and_stop(self) -> Tuple[Optional[int], Optional[int]]:
        return None, None

    _categories: List[str]

    def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]:
        image, label = data
        return dict(
            image=Image(image),
            label=Label(label, dtype=torch.int64, categories=self._categories),
        )

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        images_dp, labels_dp = resource_dps
        start, stop = self.start_and_stop()

        images_dp = Decompressor(images_dp)
        images_dp = MNISTFileReader(images_dp, start=start, stop=stop)

        labels_dp = Decompressor(labels_dp)
        labels_dp = MNISTFileReader(labels_dp, start=start, stop=stop)

        dp = Zipper(images_dp, labels_dp)
        dp = hint_shuffling(dp)
        dp = hint_sharding(dp)
        return Mapper(dp, self._prepare_sample)


@register_info("mnist")
def _mnist_info() -> Dict[str, Any]:
    return dict(
        categories=[str(label) for label in range(10)],
    )


@register_dataset("mnist")
class MNIST(_MNISTBase):
    """
    - **homepage**: http://yann.lecun.com/exdb/mnist
    """

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        *,
        split: str = "train",
        skip_integrity_check: bool = False,
    ) -> None:
        self._split = self._verify_str_arg(split, "split", ("train", "test"))
        super().__init__(root, skip_integrity_check=skip_integrity_check)

    _URL_BASE: Union[str, Sequence[str]] = (
        "http://yann.lecun.com/exdb/mnist",
        "https://ossci-datasets.s3.amazonaws.com/mnist",
    )
    _CHECKSUMS = {
        "train-images-idx3-ubyte.gz": "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609",
        "train-labels-idx1-ubyte.gz": "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c",
        "t10k-images-idx3-ubyte.gz": "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6",
        "t10k-labels-idx1-ubyte.gz": "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6",
    }

    def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]:
        prefix = "train" if self._split == "train" else "t10k"
        images_file = f"{prefix}-images-idx3-ubyte.gz"
        labels_file = f"{prefix}-labels-idx1-ubyte.gz"
        return (images_file, self._CHECKSUMS[images_file]), (
            labels_file,
            self._CHECKSUMS[labels_file],
        )

    _categories = _mnist_info()["categories"]

    def __len__(self) -> int:
        return 60_000 if self._split == "train" else 10_000


@register_info("fashionmnist")
def _fashionmnist_info() -> Dict[str, Any]:
    return dict(
        categories=[
            "T-shirt/top",
            "Trouser",
            "Pullover",
            "Dress",
            "Coat",
            "Sandal",
            "Shirt",
            "Sneaker",
            "Bag",
            "Ankle boot",
        ],
    )


@register_dataset("fashionmnist")
class FashionMNIST(MNIST):
    """
    - **homepage**: https://github.com/zalandoresearch/fashion-mnist
    """

    _URL_BASE = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com"
    _CHECKSUMS = {
        "train-images-idx3-ubyte.gz": "3aede38d61863908ad78613f6a32ed271626dd12800ba2636569512369268a84",
        "train-labels-idx1-ubyte.gz": "a04f17134ac03560a47e3764e11b92fc97de4d1bfaf8ba1a3aa29af54cc90845",
        "t10k-images-idx3-ubyte.gz": "346e55b948d973a97e58d2351dde16a484bd415d4595297633bb08f03db6a073",
        "t10k-labels-idx1-ubyte.gz": "67da17c76eaffca5446c3361aaab5c3cd6d1c2608764d35dfb1850b086bf8dd5",
    }

    _categories = _fashionmnist_info()["categories"]


@register_info("kmnist")
def _kmnist_info() -> Dict[str, Any]:
    return dict(
        categories=["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"],
    )


@register_dataset("kmnist")
class KMNIST(MNIST):
    """
    - **homepage**: http://codh.rois.ac.jp/kmnist/index.html.en
    """

    _URL_BASE = "http://codh.rois.ac.jp/kmnist/dataset/kmnist"
    _CHECKSUMS = {
        "train-images-idx3-ubyte.gz": "51467d22d8cc72929e2a028a0428f2086b092bb31cfb79c69cc0a90ce135fde4",
        "train-labels-idx1-ubyte.gz": "e38f9ebcd0f3ebcdec7fc8eabdcdaef93bb0df8ea12bee65224341c8183d8e17",
        "t10k-images-idx3-ubyte.gz": "edd7a857845ad6bb1d0ba43fe7e794d164fe2dce499a1694695a792adfac43c5",
        "t10k-labels-idx1-ubyte.gz": "20bb9a0ef54c7db3efc55a92eef5582c109615df22683c380526788f98e42a1c",
    }

    _categories = _kmnist_info()["categories"]


@register_info("emnist")
def _emnist_info() -> Dict[str, Any]:
    return dict(
        categories=list(string.digits + string.ascii_uppercase + string.ascii_lowercase),
    )


@register_dataset("emnist")
class EMNIST(_MNISTBase):
    """
    - **homepage**: https://www.westernsydney.edu.au/icns/reproducible_research/publication_support_materials/emnist
    """

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        *,
        split: str = "train",
        image_set: str = "Balanced",
        skip_integrity_check: bool = False,
    ) -> None:
        self._split = self._verify_str_arg(split, "split", ("train", "test"))
        self._image_set = self._verify_str_arg(
            image_set, "image_set", ("Balanced", "By_Merge", "By_Class", "Letters", "Digits", "MNIST")
        )
        super().__init__(root, skip_integrity_check=skip_integrity_check)

    _URL_BASE = "https://rds.westernsydney.edu.au/Institutes/MARCS/BENS/EMNIST"

    def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]:
        prefix = f"emnist-{self._image_set.replace('_', '').lower()}-{self._split}"
        images_file = f"{prefix}-images-idx3-ubyte.gz"
        labels_file = f"{prefix}-labels-idx1-ubyte.gz"
        # Since EMNIST provides the data files inside an archive, we don't need to provide checksums for them
        return (images_file, ""), (labels_file, "")

    def _resources(self) -> List[OnlineResource]:
        return [
            HttpResource(
                f"{self._URL_BASE}/emnist-gzip.zip",
                sha256="909a2a39c5e86bdd7662425e9b9c4a49bb582bf8d0edad427f3c3a9d0c6f7259",
            )
        ]

    def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
        path = pathlib.Path(data[0])
        (images_file, _), (labels_file, _) = self._files_and_checksums()
        if path.name == images_file:
            return 0
        elif path.name == labels_file:
            return 1
        else:
            return None

    _categories = _emnist_info()["categories"]

    _LABEL_OFFSETS = {
        38: 1,
        39: 1,
        40: 1,
        41: 1,
        42: 1,
        43: 6,
        44: 8,
        45: 8,
        46: 9,
    }

    def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]:
        # In these two splits, some lowercase letters are merged into their uppercase ones (see Fig 2. in the paper).
        # That means for example that there is 'D', 'd', and 'C', but not 'c'. Since the labels are nevertheless dense,
        # i.e. no gaps between 0 and 46 for 47 total classes, we need to add an offset to create these gaps. For
        # example, since there is no 'c', 'd' corresponds to
        # label 38 (10 digits + 26 uppercase letters + 3rd unmerged lower case letter - 1 for zero indexing),
        # and at the same time corresponds to
        # index 39 (10 digits + 26 uppercase letters + 4th lower case letter - 1 for zero indexing)
        # in self._categories. Thus, we need to add 1 to the label to correct this.
        if self._image_set in ("Balanced", "By_Merge"):
            image, label = data
            label += self._LABEL_OFFSETS.get(int(label), 0)
            data = (image, label)
        return super()._prepare_sample(data)

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        archive_dp = resource_dps[0]
        images_dp, labels_dp = Demultiplexer(
            archive_dp,
            2,
            self._classify_archive,
            drop_none=True,
            buffer_size=INFINITE_BUFFER_SIZE,
        )
        return super()._datapipe([images_dp, labels_dp])

    def __len__(self) -> int:
        return {
            ("train", "Balanced"): 112_800,
            ("train", "By_Merge"): 697_932,
            ("train", "By_Class"): 697_932,
            ("train", "Letters"): 124_800,
            ("train", "Digits"): 240_000,
            ("train", "MNIST"): 60_000,
            ("test", "Balanced"): 18_800,
            ("test", "By_Merge"): 116_323,
            ("test", "By_Class"): 116_323,
            ("test", "Letters"): 20_800,
            ("test", "Digits"): 40_000,
            ("test", "MNIST"): 10_000,
        }[(self._split, self._image_set)]


@register_info("qmnist")
def _qmnist_info() -> Dict[str, Any]:
    return dict(
        categories=[str(label) for label in range(10)],
    )


@register_dataset("qmnist")
class QMNIST(_MNISTBase):
    """
    - **homepage**: https://github.com/facebookresearch/qmnist
    """

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        *,
        split: str = "train",
        skip_integrity_check: bool = False,
    ) -> None:
        self._split = self._verify_str_arg(split, "split", ("train", "test", "test10k", "test50k", "nist"))
        super().__init__(root, skip_integrity_check=skip_integrity_check)

    _URL_BASE = "https://raw.githubusercontent.com/facebookresearch/qmnist/master"
    _CHECKSUMS = {
        "qmnist-train-images-idx3-ubyte.gz": "9e26a7bf1683614e065d7b76460ccd52807165b3f22561fb782bd9f38c52b51d",
        "qmnist-train-labels-idx2-int.gz": "2c05dc77f6b916b38e455e97ab129a42a444f3dbef09b278a366f82904e0dd9f",
        "qmnist-test-images-idx3-ubyte.gz": "43fc22bf7498b8fc98de98369d72f752d0deabc280a43a7bcc364ab19e57b375",
        "qmnist-test-labels-idx2-int.gz": "9fbcbe594c3766fdf4f0b15c5165dc0d1e57ac604e01422608bb72c906030d06",
        "xnist-images-idx3-ubyte.xz": "f075553993026d4359ded42208eff77a1941d3963c1eff49d6015814f15f0984",
        "xnist-labels-idx2-int.xz": "db042968723ec2b7aed5f1beac25d2b6e983b9286d4f4bf725f1086e5ae55c4f",
    }

    def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]:
        prefix = "xnist" if self._split == "nist" else f"qmnist-{'train' if self._split == 'train' else 'test'}"
        suffix = "xz" if self._split == "nist" else "gz"
        images_file = f"{prefix}-images-idx3-ubyte.{suffix}"
        labels_file = f"{prefix}-labels-idx2-int.{suffix}"
        return (images_file, self._CHECKSUMS[images_file]), (
            labels_file,
            self._CHECKSUMS[labels_file],
        )

    def start_and_stop(self) -> Tuple[Optional[int], Optional[int]]:
        start: Optional[int]
        stop: Optional[int]
        if self._split == "test10k":
            start = 0
            stop = 10000
        elif self._split == "test50k":
            start = 10000
            stop = None
        else:
            start = stop = None

        return start, stop

    _categories = _emnist_info()["categories"]

    def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]:
        image, ann = data
        label, *extra_anns = ann
        sample = super()._prepare_sample((image, label))

        sample.update(
            dict(
                zip(
                    ("nist_hsf_series", "nist_writer_id", "digit_index", "nist_label", "global_digit_index"),
                    [int(value) for value in extra_anns[:5]],
                )
            )
        )
        sample.update(dict(zip(("duplicate", "unused"), [bool(value) for value in extra_anns[-2:]])))
        return sample

    def __len__(self) -> int:
        return {
            "train": 60_000,
            "test": 60_000,
            "test10k": 10_000,
            "test50k": 50_000,
            "nist": 402_953,
        }[self._split]
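The arithmetic at the top of MNISTFileReader.__iter__ decodes the IDX header: in the big-endian magic number the third byte is the dtype code and the fourth is the dimension count, so magic // 256 recovers the code and magic % 256 - 1 the per-sample rank (the first dimension is the sample count, read separately). A hedged sketch of the same decoding using only the standard library, under the IDX layout the reader assumes:

import struct
from typing import BinaryIO, List, Tuple

# Bytes per value for each dtype code, mirroring _DTYPE_MAP above; the reader
# uses the same quantity to seek past skipped samples when start > 0.
DTYPE_SIZES = {8: 1, 9: 1, 11: 2, 12: 4, 13: 4, 14: 8}

def read_idx_header(file: BinaryIO) -> Tuple[int, int, List[int]]:
    # Magic layout (big endian): 0x00 0x00 <dtype code> <1 + per-sample rank>.
    (magic,) = struct.unpack(">i", file.read(4))
    dtype_code, ndim = magic // 256, magic % 256 - 1
    # The first dimension is the number of samples ...
    (num_samples,) = struct.unpack(">i", file.read(4))
    # ... followed by one big-endian int32 per remaining dimension.
    shape = list(struct.unpack(f">{ndim}i", file.read(4 * ndim))) if ndim else []
    return dtype_code, num_samples, shape

# For train-images-idx3-ubyte the magic is 0x00000803, so dtype_code == 8
# (uint8) and the per-sample shape is [28, 28] across 60_000 samples.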
torchvision/prototype/datasets/_builtin/oxford-iiit-pet.categories deleted 100644 → 0
Abyssinian
American Bulldog
American Pit Bull Terrier
Basset Hound
Beagle
Bengal
Birman
Bombay
Boxer
British Shorthair
Chihuahua
Egyptian Mau
English Cocker Spaniel
English Setter
German Shorthaired
Great Pyrenees
Havanese
Japanese Chin
Keeshond
Leonberger
Maine Coon
Miniature Pinscher
Newfoundland
Persian
Pomeranian
Pug
Ragdoll
Russian Blue
Saint Bernard
Samoyed
Scottish Terrier
Shiba Inu
Siamese
Sphynx
Staffordshire Bull Terrier
Wheaten Terrier
Yorkshire Terrier
torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py deleted 100644 → 0
import enum
import pathlib
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union

from torchdata.datapipes.iter import CSVDictParser, Demultiplexer, Filter, IterDataPipe, IterKeyZipper, Mapper
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
    getitem,
    hint_sharding,
    hint_shuffling,
    INFINITE_BUFFER_SIZE,
    path_accessor,
    path_comparator,
    read_categories_file,
)
from torchvision.prototype.tv_tensors import Label

from .._api import register_dataset, register_info

NAME = "oxford-iiit-pet"


class OxfordIIITPetDemux(enum.IntEnum):
    SPLIT_AND_CLASSIFICATION = 0
    SEGMENTATIONS = 1


@register_info(NAME)
def _info() -> Dict[str, Any]:
    return dict(categories=read_categories_file(NAME))


@register_dataset(NAME)
class OxfordIIITPet(Dataset):
    """Oxford IIIT Pet Dataset
    homepage="https://www.robots.ox.ac.uk/~vgg/data/pets/",
    """

    def __init__(
        self, root: Union[str, pathlib.Path], *, split: str = "trainval", skip_integrity_check: bool = False
    ) -> None:
        self._split = self._verify_str_arg(split, "split", {"trainval", "test"})
        self._categories = _info()["categories"]
        super().__init__(root, skip_integrity_check=skip_integrity_check)

    def _resources(self) -> List[OnlineResource]:
        images = HttpResource(
            "https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz",
            sha256="67195c5e1c01f1ab5f9b6a5d22b8c27a580d896ece458917e61d459337fa318d",
            preprocess="decompress",
        )
        anns = HttpResource(
            "https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz",
            sha256="52425fb6de5c424942b7626b428656fcbd798db970a937df61750c0f1d358e91",
            preprocess="decompress",
        )
        return [images, anns]

    def _classify_anns(self, data: Tuple[str, Any]) -> Optional[int]:
        return {
            "annotations": OxfordIIITPetDemux.SPLIT_AND_CLASSIFICATION,
            "trimaps": OxfordIIITPetDemux.SEGMENTATIONS,
        }.get(pathlib.Path(data[0]).parent.name)

    def _filter_images(self, data: Tuple[str, Any]) -> bool:
        return pathlib.Path(data[0]).suffix == ".jpg"

    def _filter_segmentations(self, data: Tuple[str, Any]) -> bool:
        return not pathlib.Path(data[0]).name.startswith(".")

    def _prepare_sample(
        self, data: Tuple[Tuple[Dict[str, str], Tuple[str, BinaryIO]], Tuple[str, BinaryIO]]
    ) -> Dict[str, Any]:
        ann_data, image_data = data
        classification_data, segmentation_data = ann_data
        segmentation_path, segmentation_buffer = segmentation_data
        image_path, image_buffer = image_data

        return dict(
            label=Label(int(classification_data["label"]) - 1, categories=self._categories),
            species="cat" if classification_data["species"] == "1" else "dog",
            segmentation_path=segmentation_path,
            segmentation=EncodedImage.from_file(segmentation_buffer),
            image_path=image_path,
            image=EncodedImage.from_file(image_buffer),
        )

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        images_dp, anns_dp = resource_dps

        images_dp = Filter(images_dp, self._filter_images)

        split_and_classification_dp, segmentations_dp = Demultiplexer(
            anns_dp,
            2,
            self._classify_anns,
            drop_none=True,
            buffer_size=INFINITE_BUFFER_SIZE,
        )

        split_and_classification_dp = Filter(
            split_and_classification_dp, path_comparator("name", f"{self._split}.txt")
        )
        split_and_classification_dp = CSVDictParser(
            split_and_classification_dp, fieldnames=("image_id", "label", "species"), delimiter=" "
        )
        split_and_classification_dp = hint_shuffling(split_and_classification_dp)
        split_and_classification_dp = hint_sharding(split_and_classification_dp)

        segmentations_dp = Filter(segmentations_dp, self._filter_segmentations)

        anns_dp = IterKeyZipper(
            split_and_classification_dp,
            segmentations_dp,
            key_fn=getitem("image_id"),
            ref_key_fn=path_accessor("stem"),
            buffer_size=INFINITE_BUFFER_SIZE,
        )

        dp = IterKeyZipper(
            anns_dp,
            images_dp,
            key_fn=getitem(0, "image_id"),
            ref_key_fn=path_accessor("stem"),
            buffer_size=INFINITE_BUFFER_SIZE,
        )
        return Mapper(dp, self._prepare_sample)

    def _filter_split_and_classification_anns(self, data: Tuple[str, Any]) -> bool:
        return self._classify_anns(data) == OxfordIIITPetDemux.SPLIT_AND_CLASSIFICATION

    def _generate_categories(self) -> List[str]:
        resources = self._resources()

        dp = resources[1].load(self._root)
        dp = Filter(dp, self._filter_split_and_classification_anns)
        dp = Filter(dp, path_comparator("name", "trainval.txt"))
        dp = CSVDictParser(dp, fieldnames=("image_id", "label"), delimiter=" ")

        raw_categories_and_labels = {(data["image_id"].rsplit("_", 1)[0], data["label"]) for data in dp}
        raw_categories, _ = zip(
            *sorted(raw_categories_and_labels, key=lambda raw_category_and_label: int(raw_category_and_label[1]))
        )
        return [" ".join(part.title() for part in raw_category.split("_")) for raw_category in raw_categories]

    def __len__(self) -> int:
        return 3_680 if self._split == "trainval" else 3_669
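_generate_categories above recovers the display names in oxford-iiit-pet.categories from the image ids in the annotation file: each id embeds the class name plus a per-image counter, so stripping the counter and title-casing the parts yields the category. A small illustration of that derivation (the sample id is illustrative, not taken from the file):

def image_id_to_category(image_id: str) -> str:
    # Drop the trailing per-image counter, then title-case the remaining parts.
    raw_category = image_id.rsplit("_", 1)[0]
    return " ".join(part.title() for part in raw_category.split("_"))

assert image_id_to_category("american_pit_bull_terrier_12") == "American Pit Bull Terrier"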