Commit 2256b495 (Unverified)
Authored Oct 07, 2021 by Philip Meier; committed by GitHub on Oct 07, 2021
improve prototype CIFAR implementation (#4558)
Co-authored-by: Prabhat Roy <prabhatroy@fb.com>
Parent: a485b8c6

Showing 1 changed file with 43 additions and 96 deletions:
torchvision/prototype/datasets/_builtin/cifar.py (+43, -96)
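In short: the old pipeline unpickled each batch file into a dict, split it with Demultiplexer into separate image and label streams, enumerated both, and re-joined them with KeyZipper; the new code replaces all of that with a single CifarFileReader datapipe that yields (image, label) pairs directly. The per-dataset _meta_file_name/_categories_key properties and the _split_data_file hook likewise give way to plain _LABELS_KEY, _META_FILE_NAME, and _CATEGORIES_KEY class attributes.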
--- a/torchvision/prototype/datasets/_builtin/cifar.py
+++ b/torchvision/prototype/datasets/_builtin/cifar.py
@@ -3,19 +3,17 @@ import functools
 import io
 import pathlib
 import pickle
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union, TypeVar
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Iterator
 
 import numpy as np
 import torch
 from torch.utils.data import IterDataPipe
 from torch.utils.data.datapipes.iter import (
-    Demultiplexer,
     Filter,
     Mapper,
     TarArchiveReader,
     Shuffler,
 )
-from torchdata.datapipes.iter import KeyZipper
 from torchvision.prototype.datasets.decoder import raw
 from torchvision.prototype.datasets.utils import (
     Dataset,
@@ -27,28 +25,35 @@ from torchvision.prototype.datasets.utils import (
 )
 from torchvision.prototype.datasets.utils._internal import (
     create_categories_file,
-    MappingIterator,
-    SequenceIterator,
     INFINITE_BUFFER_SIZE,
     image_buffer_from_array,
-    Enumerator,
-    getitem,
+    path_comparator,
 )
 
 __all__ = ["Cifar10", "Cifar100"]
 
 HERE = pathlib.Path(__file__).parent
 
-D = TypeVar("D")
+
+class CifarFileReader(IterDataPipe[Tuple[np.ndarray, int]]):
+    def __init__(self, datapipe: IterDataPipe[Dict[str, Any]], *, labels_key: str) -> None:
+        self.datapipe = datapipe
+        self.labels_key = labels_key
+
+    def __iter__(self) -> Iterator[Tuple[np.ndarray, int]]:
+        for mapping in self.datapipe:
+            image_arrays = mapping["data"].reshape((-1, 3, 32, 32))
+            category_idcs = mapping[self.labels_key]
+            yield from iter(zip(image_arrays, category_idcs))
 
 
 class _CifarBase(Dataset):
-    @abc.abstractmethod
-    def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> Optional[int]:
-        pass
+    _LABELS_KEY: str
+    _META_FILE_NAME: str
+    _CATEGORIES_KEY: str
 
     @abc.abstractmethod
-    def _split_data_file(self, data: Tuple[str, Any]) -> Optional[int]:
+    def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> Optional[int]:
         pass
 
     def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
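To make the new CifarFileReader concrete, here is a minimal standalone sketch (the batch dict is fabricated for illustration; on disk, a CIFAR batch stores N images as flat rows of 3 * 32 * 32 = 3072 uint8 values under "data"):

import numpy as np

# Hypothetical stand-in for one unpickled CIFAR-10 batch file.
batch = {
    "data": np.zeros((2, 3072), dtype=np.uint8),  # two flattened 32x32 RGB images
    "labels": [3, 7],
}

# The two steps CifarFileReader.__iter__ performs per mapping:
image_arrays = batch["data"].reshape((-1, 3, 32, 32))  # -> N x C x H x W
for image_array, category_idx in zip(image_arrays, batch["labels"]):
    print(image_array.shape, category_idx)  # (3, 32, 32) 3, then (3, 32, 32) 7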
@@ -57,21 +62,20 @@ class _CifarBase(Dataset):
     def _collate_and_decode(
         self,
-        data: Tuple[Tuple[int, int], Tuple[int, np.ndarray]],
+        data: Tuple[np.ndarray, int],
         *,
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> Dict[str, Any]:
-        (_, category_idx), (_, image_array_flat) = data
+        image_array, category_idx = data
 
         category = self.categories[category_idx]
         label = torch.tensor(category_idx)
 
-        image_array = image_array_flat.reshape((3, 32, 32))
         image: Union[torch.Tensor, io.BytesIO]
         if decoder is raw:
             image = torch.from_numpy(image_array)
         else:
-            image_buffer = image_buffer_from_array(image_array.transpose(1, 2, 0))
+            image_buffer = image_buffer_from_array(image_array.transpose((1, 2, 0)))
             image = decoder(image_buffer) if decoder else image_buffer
 
         return dict(label=label, category=category, image=image)
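A small illustrative sketch of the two image paths in _collate_and_decode: with decoder is raw, the CHW array is wrapped as a tensor unchanged; otherwise it is transposed to HWC before being turned into an image buffer (PIL-style helpers conventionally expect height-width-channel order):

import numpy as np
import torch

chw = np.zeros((3, 32, 32), dtype=np.uint8)  # what CifarFileReader yields

print(torch.from_numpy(chw).shape)     # torch.Size([3, 32, 32]) -- the raw path
print(chw.transpose((1, 2, 0)).shape)  # (32, 32, 3) -- the buffer/decoder path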
@@ -83,55 +87,32 @@ class _CifarBase(Dataset):
         config: DatasetConfig,
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
-        archive_dp = resource_dps[0]
-        archive_dp = TarArchiveReader(archive_dp)
-        archive_dp: IterDataPipe = Filter(archive_dp, functools.partial(self._is_data_file, config=config))
-        archive_dp: IterDataPipe = Mapper(archive_dp, self._unpickle)
-        archive_dp = MappingIterator(archive_dp)
-        images_dp, labels_dp = Demultiplexer(
-            archive_dp,
-            2,
-            self._split_data_file,  # type: ignore[arg-type]
-            drop_none=True,
-            buffer_size=INFINITE_BUFFER_SIZE,
-        )
-
-        labels_dp: IterDataPipe = Mapper(labels_dp, getitem(1))
-        labels_dp: IterDataPipe = SequenceIterator(labels_dp)
-        labels_dp = Enumerator(labels_dp)
-        labels_dp = Shuffler(labels_dp, buffer_size=INFINITE_BUFFER_SIZE)
-
-        images_dp: IterDataPipe = Mapper(images_dp, getitem(1))
-        images_dp: IterDataPipe = SequenceIterator(images_dp)
-        images_dp = Enumerator(images_dp)
-
-        dp = KeyZipper(labels_dp, images_dp, getitem(0), buffer_size=INFINITE_BUFFER_SIZE)
+        dp = resource_dps[0]
+        dp: IterDataPipe = TarArchiveReader(dp)
+        dp: IterDataPipe = Filter(dp, functools.partial(self._is_data_file, config=config))
+        dp: IterDataPipe = Mapper(dp, self._unpickle)
+        dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
+        dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
         return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))
 
-    @property
-    @abc.abstractmethod
-    def _meta_file_name(self) -> str:
-        pass
-
-    @property
-    @abc.abstractmethod
-    def _categories_key(self) -> str:
-        pass
-
-    def _is_meta_file(self, data: Tuple[str, Any]) -> bool:
-        path = pathlib.Path(data[0])
-        return path.name == self._meta_file_name
-
     def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None:
         dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
         dp = TarArchiveReader(dp)
-        dp: IterDataPipe = Filter(dp, self._is_meta_file)
+        dp: IterDataPipe = Filter(dp, path_comparator("name", self._META_FILE_NAME))
         dp: IterDataPipe = Mapper(dp, self._unpickle)
-        categories = next(iter(dp))[self._categories_key]
+        categories = next(iter(dp))[self._CATEGORIES_KEY]
         create_categories_file(HERE, self.name, categories)
 
 
 class Cifar10(_CifarBase):
+    _LABELS_KEY = "labels"
+    _META_FILE_NAME = "batches.meta"
+    _CATEGORIES_KEY = "label_names"
+
+    def _is_data_file(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool:
+        path = pathlib.Path(data[0])
+        return path.name.startswith("data" if config.split == "train" else "test")
+
     @property
     def info(self) -> DatasetInfo:
         return DatasetInfo(
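As a sanity check on the new Cifar10._is_data_file, a sketch with hand-written archive member names (the real CIFAR-10 tarball ships data_batch_1 through data_batch_5 plus test_batch):

import pathlib

members = [
    "cifar-10-batches-py/data_batch_1",
    "cifar-10-batches-py/data_batch_5",
    "cifar-10-batches-py/test_batch",
    "cifar-10-batches-py/batches.meta",
]

for split in ("train", "test"):
    prefix = "data" if split == "train" else "test"
    print(split, [m for m in members if pathlib.Path(m).name.startswith(prefix)])
# train keeps the data_batch_* members; test keeps only test_batch.
# batches.meta matches neither prefix and is filtered out in both splits.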
@@ -149,29 +130,16 @@ class Cifar10(_CifarBase):
         )
     ]
 
-    def _is_data_file(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool:
-        path = pathlib.Path(data[0])
-        return path.name.startswith("data" if config.split == "train" else "test")
-
-    def _split_data_file(self, data: Tuple[str, Any]) -> Optional[int]:
-        key, _ = data
-        if key == "data":
-            return 0
-        elif key == "labels":
-            return 1
-        else:
-            return None
-
-    @property
-    def _meta_file_name(self) -> str:
-        return "batches.meta"
-
-    @property
-    def _categories_key(self) -> str:
-        return "label_names"
-
 
 class Cifar100(_CifarBase):
+    _LABELS_KEY = "fine_labels"
+    _META_FILE_NAME = "meta"
+    _CATEGORIES_KEY = "fine_label_names"
+
+    def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> bool:
+        path = pathlib.Path(data[0])
+        return path.name == config.split
+
     @property
     def info(self) -> DatasetInfo:
         return DatasetInfo(
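The CIFAR-100 archive stores one file per split, literally named train and test, so the new Cifar100._is_data_file can compare names exactly; an illustrative sketch:

import pathlib

members = ["cifar-100-python/train", "cifar-100-python/test", "cifar-100-python/meta"]
for split in ("train", "test"):
    print(split, [m for m in members if pathlib.Path(m).name == split])
# train -> ['cifar-100-python/train']; test -> ['cifar-100-python/test']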
@@ -192,27 +160,6 @@ class Cifar100(_CifarBase):
         )
     ]
 
-    def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> bool:
-        path = pathlib.Path(data[0])
-        return path.name == config.split
-
-    def _split_data_file(self, data: Tuple[str, Any]) -> Optional[int]:
-        key, _ = data
-        if key == "data":
-            return 0
-        elif key == "fine_labels":
-            return 1
-        else:
-            return None
-
-    @property
-    def _meta_file_name(self) -> str:
-        return "meta"
-
-    @property
-    def _categories_key(self) -> str:
-        return "fine_label_names"
-
 
 if __name__ == "__main__":
     from torchvision.prototype.datasets import home
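For contrast, a sketch of what the deleted _split_data_file hooks were doing: Demultiplexer called them on each (key, value) pair of the unpickled dict and routed the pair to output 0 (images) or 1 (labels), dropping everything that mapped to None (a plain-Python imitation, not DataPipe code):

def split_data_file(data):
    key, _ = data
    if key == "data":
        return 0
    elif key == "fine_labels":  # "labels" in the CIFAR-10 variant
        return 1
    else:
        return None  # e.g. "filenames"; dropped because drop_none=True

routed = {0: [], 1: []}
for item in [("data", "<images>"), ("fine_labels", "<labels>"), ("filenames", "<names>")]:
    idx = split_data_file(item)
    if idx is not None:
        routed[idx].append(item[0])
print(routed)  # {0: ['data'], 1: ['fine_labels']}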