Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
b4686f2b
Unverified
Commit
b4686f2b
authored
Sep 13, 2022
by
Philip Meier
Committed by
GitHub
Sep 13, 2022
Browse files
Fully exhaust datapipes that are needed to construct a dataset (#6076)
parent
2c19af37
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
42 additions
and
32 deletions
+42
-32
torchvision/prototype/datasets/_builtin/cub200.py
torchvision/prototype/datasets/_builtin/cub200.py
+8
-5
torchvision/prototype/datasets/_builtin/imagenet.py
torchvision/prototype/datasets/_builtin/imagenet.py
+34
-27
No files found.
torchvision/prototype/datasets/_builtin/cub200.py
View file @
b4686f2b
...
@@ -13,6 +13,7 @@ from torchdata.datapipes.iter import (
...
@@ -13,6 +13,7 @@ from torchdata.datapipes.iter import (
LineReader
,
LineReader
,
Mapper
,
Mapper
,
)
)
from
torchdata.datapipes.map
import
IterToMapConverter
from
torchvision.prototype.datasets.utils
import
Dataset
,
GDriveResource
,
OnlineResource
from
torchvision.prototype.datasets.utils
import
Dataset
,
GDriveResource
,
OnlineResource
from
torchvision.prototype.datasets.utils._internal
import
(
from
torchvision.prototype.datasets.utils._internal
import
(
getitem
,
getitem
,
...
@@ -114,6 +115,9 @@ class CUB200(Dataset):
...
@@ -114,6 +115,9 @@ class CUB200(Dataset):
else
:
else
:
return
None
return
None
def
_2011_extract_file_name
(
self
,
rel_posix_path
:
str
)
->
str
:
return
rel_posix_path
.
rsplit
(
"/"
,
maxsplit
=
1
)[
1
]
def
_2011_filter_split
(
self
,
row
:
List
[
str
])
->
bool
:
def
_2011_filter_split
(
self
,
row
:
List
[
str
])
->
bool
:
_
,
split_id
=
row
_
,
split_id
=
row
return
{
return
{
...
@@ -185,17 +189,16 @@ class CUB200(Dataset):
...
@@ -185,17 +189,16 @@ class CUB200(Dataset):
)
)
image_files_dp
=
CSVParser
(
image_files_dp
,
dialect
=
"cub200"
)
image_files_dp
=
CSVParser
(
image_files_dp
,
dialect
=
"cub200"
)
image_files_map
=
dict
(
image_files_dp
=
Mapper
(
image_files_dp
,
self
.
_2011_extract_file_name
,
input_col
=
1
)
(
image_id
,
rel_posix_path
.
rsplit
(
"/"
,
maxsplit
=
1
)[
1
])
for
image_id
,
rel_posix_path
in
image_files_dp
image_files_map
=
IterToMapConverter
(
image_files_dp
)
)
split_dp
=
CSVParser
(
split_dp
,
dialect
=
"cub200"
)
split_dp
=
CSVParser
(
split_dp
,
dialect
=
"cub200"
)
split_dp
=
Filter
(
split_dp
,
self
.
_2011_filter_split
)
split_dp
=
Filter
(
split_dp
,
self
.
_2011_filter_split
)
split_dp
=
Mapper
(
split_dp
,
getitem
(
0
))
split_dp
=
Mapper
(
split_dp
,
getitem
(
0
))
split_dp
=
Mapper
(
split_dp
,
image_files_map
.
get
)
split_dp
=
Mapper
(
split_dp
,
image_files_map
.
__
get
item__
)
bounding_boxes_dp
=
CSVParser
(
bounding_boxes_dp
,
dialect
=
"cub200"
)
bounding_boxes_dp
=
CSVParser
(
bounding_boxes_dp
,
dialect
=
"cub200"
)
bounding_boxes_dp
=
Mapper
(
bounding_boxes_dp
,
image_files_map
.
get
,
input_col
=
0
)
bounding_boxes_dp
=
Mapper
(
bounding_boxes_dp
,
image_files_map
.
__
get
item__
,
input_col
=
0
)
anns_dp
=
IterKeyZipper
(
anns_dp
=
IterKeyZipper
(
bounding_boxes_dp
,
bounding_boxes_dp
,
...
...
torchvision/prototype/datasets/_builtin/imagenet.py
View file @
b4686f2b
import
enum
import
enum
import
functools
import
pathlib
import
pathlib
import
re
import
re
from
typing
import
Any
,
BinaryIO
,
cast
,
Dict
,
List
,
Match
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
BinaryIO
,
cast
,
Dict
,
Iterator
,
List
,
Match
,
Optional
,
Tuple
,
Union
from
torchdata.datapipes.iter
import
(
from
torchdata.datapipes.iter
import
(
Demultiplexer
,
Demultiplexer
,
...
@@ -14,6 +14,7 @@ from torchdata.datapipes.iter import (
...
@@ -14,6 +14,7 @@ from torchdata.datapipes.iter import (
Mapper
,
Mapper
,
TarArchiveLoader
,
TarArchiveLoader
,
)
)
from
torchdata.datapipes.map
import
IterToMapConverter
from
torchvision.prototype.datasets.utils
import
Dataset
,
ManualDownloadResource
,
OnlineResource
from
torchvision.prototype.datasets.utils
import
Dataset
,
ManualDownloadResource
,
OnlineResource
from
torchvision.prototype.datasets.utils._internal
import
(
from
torchvision.prototype.datasets.utils._internal
import
(
getitem
,
getitem
,
...
@@ -47,6 +48,28 @@ class ImageNetDemux(enum.IntEnum):
...
@@ -47,6 +48,28 @@ class ImageNetDemux(enum.IntEnum):
LABEL
=
1
LABEL
=
1
class CategoryAndWordNetIDExtractor(IterDataPipe):
    """Yield ``(category, wnid)`` pairs parsed from devkit meta ``.mat`` streams."""

    # Although the WordNet IDs (wnids) are unique, the corresponding categories
    # are not: both n02012849 and n03126707 are labeled 'crane', where the first
    # means the bird and the latter the construction equipment. These need a
    # manual disambiguation map.
    _WNID_MAP = {
        "n03126707": "construction crane",
        "n03710721": "tank suit",
    }

    def __init__(self, datapipe: IterDataPipe[Tuple[str, BinaryIO]]) -> None:
        self.datapipe = datapipe

    def __iter__(self) -> Iterator[Tuple[str, str]]:
        for _, handle in self.datapipe:
            for entry in read_mat(handle, squeeze_me=True)["synsets"]:
                _, wnid, category, _, num_children, *_ = entry
                # Entries with children are superclasses that have no direct
                # instances; only leaf synsets are emitted.
                if num_children == 0:
                    # Categories can be comma-separated synonym lists; keep the
                    # first synonym unless the wnid has a manual override.
                    yield self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid
@
register_dataset
(
NAME
)
@
register_dataset
(
NAME
)
class
ImageNet
(
Dataset
):
class
ImageNet
(
Dataset
):
"""
"""
...
@@ -110,25 +133,6 @@ class ImageNet(Dataset):
...
@@ -110,25 +133,6 @@ class ImageNet(Dataset):
"ILSVRC2012_validation_ground_truth.txt"
:
ImageNetDemux
.
LABEL
,
"ILSVRC2012_validation_ground_truth.txt"
:
ImageNetDemux
.
LABEL
,
}.
get
(
pathlib
.
Path
(
data
[
0
]).
name
)
}.
get
(
pathlib
.
Path
(
data
[
0
]).
name
)
# Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
# and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
_WNID_MAP
=
{
"n03126707"
:
"construction crane"
,
"n03710721"
:
"tank suit"
,
}
def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tuple[str, str]]:
    """Parse a devkit meta ``.mat`` stream into a list of ``(category, wnid)`` pairs."""
    synsets = read_mat(data[1], squeeze_me=True)["synsets"]
    pairs: List[Tuple[str, str]] = []
    for _, wnid, category, _, num_children, *_ in synsets:
        # num_children > 0 marks a superclass with no direct instances; skip it.
        if num_children:
            continue
        # Categories can be comma-separated synonym lists; keep the first
        # synonym unless the wnid has a manual override in _WNID_MAP.
        pairs.append((self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid))
    return pairs
def
_imagenet_label_to_wnid
(
self
,
imagenet_label
:
str
,
*
,
wnids
:
Tuple
[
str
,
...])
->
str
:
return
wnids
[
int
(
imagenet_label
)
-
1
]
_VAL_TEST_IMAGE_NAME_PATTERN
=
re
.
compile
(
r
"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG"
)
_VAL_TEST_IMAGE_NAME_PATTERN
=
re
.
compile
(
r
"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG"
)
def
_val_test_image_key
(
self
,
path
:
pathlib
.
Path
)
->
int
:
def
_val_test_image_key
(
self
,
path
:
pathlib
.
Path
)
->
int
:
...
@@ -172,12 +176,15 @@ class ImageNet(Dataset):
...
@@ -172,12 +176,15 @@ class ImageNet(Dataset):
devkit_dp
,
2
,
self
.
_classifiy_devkit
,
drop_none
=
True
,
buffer_size
=
INFINITE_BUFFER_SIZE
devkit_dp
,
2
,
self
.
_classifiy_devkit
,
drop_none
=
True
,
buffer_size
=
INFINITE_BUFFER_SIZE
)
)
meta_dp
=
Mapper
(
meta_dp
,
self
.
_extract_categories_and_wnids
)
# We cannot use self._wnids here, since we use a different order than the dataset
_
,
wnids
=
zip
(
*
next
(
iter
(
meta_dp
)))
meta_dp
=
CategoryAndWordNetIDExtractor
(
meta_dp
)
wnid_dp
=
Mapper
(
meta_dp
,
getitem
(
1
))
wnid_dp
=
Enumerator
(
wnid_dp
,
1
)
wnid_map
=
IterToMapConverter
(
wnid_dp
)
label_dp
=
LineReader
(
label_dp
,
decode
=
True
,
return_path
=
False
)
label_dp
=
LineReader
(
label_dp
,
decode
=
True
,
return_path
=
False
)
# We cannot use self._wnids here, since we use a different order than the dataset
label_dp
=
Mapper
(
label_dp
,
int
)
label_dp
=
Mapper
(
label_dp
,
functools
.
partial
(
self
.
_imagenet_label_to_wnid
,
wnids
=
wnids
)
)
label_dp
=
Mapper
(
label_dp
,
wnid_map
.
__getitem__
)
label_dp
:
IterDataPipe
[
Tuple
[
int
,
str
]]
=
Enumerator
(
label_dp
,
1
)
label_dp
:
IterDataPipe
[
Tuple
[
int
,
str
]]
=
Enumerator
(
label_dp
,
1
)
label_dp
=
hint_shuffling
(
label_dp
)
label_dp
=
hint_shuffling
(
label_dp
)
label_dp
=
hint_sharding
(
label_dp
)
label_dp
=
hint_sharding
(
label_dp
)
...
@@ -209,8 +216,8 @@ class ImageNet(Dataset):
...
@@ -209,8 +216,8 @@ class ImageNet(Dataset):
devkit_dp
=
resources
[
1
].
load
(
self
.
_root
)
devkit_dp
=
resources
[
1
].
load
(
self
.
_root
)
meta_dp
=
Filter
(
devkit_dp
,
self
.
_filter_meta
)
meta_dp
=
Filter
(
devkit_dp
,
self
.
_filter_meta
)
meta_dp
=
Mapper
(
meta_dp
,
self
.
_extract_categories_and_wnids
)
meta_dp
=
CategoryAndWordNetIDExtractor
(
meta_dp
)
categories_and_wnids
=
cast
(
List
[
Tuple
[
str
,
...]],
next
(
iter
(
meta_dp
))
)
categories_and_wnids
=
cast
(
List
[
Tuple
[
str
,
...]],
list
(
meta_dp
))
categories_and_wnids
.
sort
(
key
=
lambda
category_and_wnid
:
category_and_wnid
[
1
])
categories_and_wnids
.
sort
(
key
=
lambda
category_and_wnid
:
category_and_wnid
[
1
])
return
categories_and_wnids
return
categories_and_wnids
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment