Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
48b1edff
Unverified
Commit
48b1edff
authored
Jun 14, 2024
by
Nicolas Hug
Committed by
GitHub
Jun 14, 2024
Browse files
Remove prototype area for 0.19 (#8491)
parent
f44f20cf
Changes
74
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1 addition
and
4132 deletions
+1
-4132
.github/workflows/prototype-tests-linux-gpu.yml
.github/workflows/prototype-tests-linux-gpu.yml
+0
-58
mypy.ini
mypy.ini
+0
-22
setup.py
setup.py
+1
-1
test/builtin_dataset_mocks.py
test/builtin_dataset_mocks.py
+0
-1582
test/test_prototype_datasets_builtin.py
test/test_prototype_datasets_builtin.py
+0
-282
test/test_prototype_datasets_utils.py
test/test_prototype_datasets_utils.py
+0
-302
test/test_prototype_models.py
test/test_prototype_models.py
+0
-84
test/test_prototype_transforms.py
test/test_prototype_transforms.py
+0
-429
torchvision/prototype/__init__.py
torchvision/prototype/__init__.py
+0
-1
torchvision/prototype/datasets/README.md
torchvision/prototype/datasets/README.md
+0
-7
torchvision/prototype/datasets/__init__.py
torchvision/prototype/datasets/__init__.py
+0
-15
torchvision/prototype/datasets/_api.py
torchvision/prototype/datasets/_api.py
+0
-65
torchvision/prototype/datasets/_builtin/README.md
torchvision/prototype/datasets/_builtin/README.md
+0
-340
torchvision/prototype/datasets/_builtin/__init__.py
torchvision/prototype/datasets/_builtin/__init__.py
+0
-22
torchvision/prototype/datasets/_builtin/caltech.py
torchvision/prototype/datasets/_builtin/caltech.py
+0
-212
torchvision/prototype/datasets/_builtin/caltech101.categories
...hvision/prototype/datasets/_builtin/caltech101.categories
+0
-101
torchvision/prototype/datasets/_builtin/caltech256.categories
...hvision/prototype/datasets/_builtin/caltech256.categories
+0
-257
torchvision/prototype/datasets/_builtin/celeba.py
torchvision/prototype/datasets/_builtin/celeba.py
+0
-200
torchvision/prototype/datasets/_builtin/cifar.py
torchvision/prototype/datasets/_builtin/cifar.py
+0
-142
torchvision/prototype/datasets/_builtin/cifar10.categories
torchvision/prototype/datasets/_builtin/cifar10.categories
+0
-10
No files found.
.github/workflows/prototype-tests-linux-gpu.yml
deleted
100644 → 0
View file @
f44f20cf
name
:
Prototype tests on Linux
# IMPORTANT: This workflow has been manually disabled from the GitHub interface
# in June 2024. The file is kept for reference in case we ever put this back.
on
:
pull_request
:
jobs
:
unittests-prototype
:
strategy
:
matrix
:
python-version
:
-
"
3.8"
-
"
3.9"
-
"
3.10"
-
"
3.11"
-
"
3.12"
runner
:
[
"
linux.12xlarge"
]
gpu-arch-type
:
[
"
cpu"
]
include
:
-
python-version
:
"
3.8"
runner
:
linux.g5.4xlarge.nvidia.gpu
gpu-arch-type
:
cuda
gpu-arch-version
:
"
11.8"
fail-fast
:
false
uses
:
pytorch/test-infra/.github/workflows/linux_job.yml@release/2.4
with
:
repository
:
pytorch/vision
runner
:
${{ matrix.runner }}
gpu-arch-type
:
${{ matrix.gpu-arch-type }}
gpu-arch-version
:
${{ matrix.gpu-arch-version }}
timeout
:
120
script
:
|
set -euo pipefail
export PYTHON_VERSION=${{ matrix.python-version }}
export GPU_ARCH_TYPE=${{ matrix.gpu-arch-type }}
export GPU_ARCH_VERSION=${{ matrix.gpu-arch-version }}
./.github/scripts/setup-env.sh
# Prepare conda
CONDA_PATH=$(which conda)
eval "$(${CONDA_PATH} shell.bash hook)"
conda activate ci
echo '::group::Install testing utilities'
pip install --progress-bar=off pytest pytest-mock pytest-cov
echo '::endgroup::'
# We don't want to run the prototype datasets tests. Since the positional glob into `pytest`, i.e.
# `test/test_prototype*.py` takes the highest priority, neither `--ignore` nor `--ignore-glob` can help us here.
rm test/test_prototype_datasets*.py
pytest \
-v --durations=25 \
--cov=torchvision/prototype --cov-report=term-missing \
--junit-xml="${RUNNER_TEST_RESULTS_DIR}/test-results.xml" \
test/test_prototype_*.py
mypy.ini
View file @
48b1edff
...
@@ -7,28 +7,6 @@ allow_redefinition = True
...
@@ -7,28 +7,6 @@ allow_redefinition = True
no_implicit_optional
=
True
no_implicit_optional
=
True
warn_redundant_casts
=
True
warn_redundant_casts
=
True
[mypy-torchvision.prototype.datapoints.*]
; untyped definitions and calls
disallow_untyped_defs
=
True
; None and Optional handling
no_implicit_optional
=
True
; warnings
warn_unused_ignores
=
True
; miscellaneous strictness flags
allow_redefinition
=
True
[mypy-torchvision.prototype.transforms.*]
ignore_errors
=
True
[mypy-torchvision.prototype.datasets.*]
ignore_errors
=
True
[mypy-torchvision.io.image.*]
[mypy-torchvision.io.image.*]
ignore_errors
=
True
ignore_errors
=
True
...
...
setup.py
View file @
48b1edff
...
@@ -550,7 +550,7 @@ if __name__ == "__main__":
...
@@ -550,7 +550,7 @@ if __name__ == "__main__":
license
=
"BSD"
,
license
=
"BSD"
,
# Package info
# Package info
packages
=
find_packages
(
exclude
=
(
"test"
,)),
packages
=
find_packages
(
exclude
=
(
"test"
,)),
package_data
=
{
package_name
:
[
"*.dll"
,
"*.dylib"
,
"*.so"
,
"prototype/datasets/_builtin/*.categories"
]},
package_data
=
{
package_name
:
[
"*.dll"
,
"*.dylib"
,
"*.so"
]},
zip_safe
=
False
,
zip_safe
=
False
,
install_requires
=
requirements
,
install_requires
=
requirements
,
extras_require
=
{
extras_require
=
{
...
...
test/builtin_dataset_mocks.py
deleted
100644 → 0
View file @
f44f20cf
This diff is collapsed.
Click to expand it.
test/test_prototype_datasets_builtin.py
deleted
100644 → 0
View file @
f44f20cf
import
io
import
pickle
from
collections
import
deque
from
pathlib
import
Path
import
pytest
import
torch
import
torchvision.transforms.v2
as
transforms
from
builtin_dataset_mocks
import
DATASET_MOCKS
,
parametrize_dataset_mocks
from
torch.testing._comparison
import
not_close_error_metas
,
ObjectPair
,
TensorLikePair
# TODO: replace with torchdata.dataloader2.DataLoader2 as soon as it is stable-ish
from
torch.utils.data
import
DataLoader
# TODO: replace with torchdata equivalent as soon as it is available
from
torch.utils.data.graph_settings
import
get_all_graph_pipes
from
torchdata.dataloader2.graph.utils
import
traverse_dps
from
torchdata.datapipes.iter
import
ShardingFilter
,
Shuffler
from
torchdata.datapipes.utils
import
StreamWrapper
from
torchvision
import
tv_tensors
from
torchvision._utils
import
sequence_to_str
from
torchvision.prototype
import
datasets
from
torchvision.prototype.datasets.utils
import
EncodedImage
from
torchvision.prototype.datasets.utils._internal
import
INFINITE_BUFFER_SIZE
from
torchvision.prototype.tv_tensors
import
Label
from
torchvision.transforms.v2._utils
import
is_pure_tensor
def
assert_samples_equal
(
*
args
,
msg
=
None
,
**
kwargs
):
error_metas
=
not_close_error_metas
(
*
args
,
pair_types
=
(
TensorLikePair
,
ObjectPair
),
rtol
=
0
,
atol
=
0
,
equal_nan
=
True
,
**
kwargs
)
if
error_metas
:
raise
error_metas
[
0
].
to_error
(
msg
)
def
extract_datapipes
(
dp
):
return
get_all_graph_pipes
(
traverse_dps
(
dp
))
def
consume
(
iterator
):
# Copied from the official itertools recipes: https://docs.python.org/3/library/itertools.html#itertools-recipes
deque
(
iterator
,
maxlen
=
0
)
def
next_consume
(
iterator
):
item
=
next
(
iterator
)
consume
(
iterator
)
return
item
@
pytest
.
fixture
(
autouse
=
True
)
def
test_home
(
mocker
,
tmp_path
):
mocker
.
patch
(
"torchvision.prototype.datasets._api.home"
,
return_value
=
str
(
tmp_path
))
mocker
.
patch
(
"torchvision.prototype.datasets.home"
,
return_value
=
str
(
tmp_path
))
yield
tmp_path
def
test_coverage
():
untested_datasets
=
set
(
datasets
.
list_datasets
())
-
DATASET_MOCKS
.
keys
()
if
untested_datasets
:
raise
AssertionError
(
f
"The dataset(s)
{
sequence_to_str
(
sorted
(
untested_datasets
),
separate_last
=
'and '
)
}
"
f
"are exposed through `torchvision.prototype.datasets.load()`, but are not tested. "
f
"Please add mock data to `test/builtin_dataset_mocks.py`."
)
@
pytest
.
mark
.
filterwarnings
(
"error"
)
class
TestCommon
:
@
pytest
.
mark
.
parametrize
(
"name"
,
datasets
.
list_datasets
())
def
test_info
(
self
,
name
):
try
:
info
=
datasets
.
info
(
name
)
except
ValueError
:
raise
AssertionError
(
"No info available."
)
from
None
if
not
(
isinstance
(
info
,
dict
)
and
all
(
isinstance
(
key
,
str
)
for
key
in
info
.
keys
())):
raise
AssertionError
(
"Info should be a dictionary with string keys."
)
@
parametrize_dataset_mocks
(
DATASET_MOCKS
)
def
test_smoke
(
self
,
dataset_mock
,
config
):
dataset
,
_
=
dataset_mock
.
load
(
config
)
if
not
isinstance
(
dataset
,
datasets
.
utils
.
Dataset
):
raise
AssertionError
(
f
"Loading the dataset should return an Dataset, but got
{
type
(
dataset
)
}
instead."
)
@
parametrize_dataset_mocks
(
DATASET_MOCKS
)
def
test_sample
(
self
,
dataset_mock
,
config
):
dataset
,
_
=
dataset_mock
.
load
(
config
)
try
:
sample
=
next_consume
(
iter
(
dataset
))
except
StopIteration
:
raise
AssertionError
(
"Unable to draw any sample."
)
from
None
except
Exception
as
error
:
raise
AssertionError
(
"Drawing a sample raised the error above."
)
from
error
if
not
isinstance
(
sample
,
dict
):
raise
AssertionError
(
f
"Samples should be dictionaries, but got
{
type
(
sample
)
}
instead."
)
if
not
sample
:
raise
AssertionError
(
"Sample dictionary is empty."
)
@
parametrize_dataset_mocks
(
DATASET_MOCKS
)
def
test_num_samples
(
self
,
dataset_mock
,
config
):
dataset
,
mock_info
=
dataset_mock
.
load
(
config
)
assert
len
(
list
(
dataset
))
==
mock_info
[
"num_samples"
]
@
pytest
.
fixture
def
log_session_streams
(
self
):
debug_unclosed_streams
=
StreamWrapper
.
debug_unclosed_streams
try
:
StreamWrapper
.
debug_unclosed_streams
=
True
yield
finally
:
StreamWrapper
.
debug_unclosed_streams
=
debug_unclosed_streams
@
parametrize_dataset_mocks
(
DATASET_MOCKS
)
def
test_stream_closing
(
self
,
log_session_streams
,
dataset_mock
,
config
):
def
make_msg_and_close
(
head
):
unclosed_streams
=
[]
for
stream
in
list
(
StreamWrapper
.
session_streams
.
keys
()):
unclosed_streams
.
append
(
repr
(
stream
.
file_obj
))
stream
.
close
()
unclosed_streams
=
"
\n
"
.
join
(
unclosed_streams
)
return
f
"
{
head
}
\n\n
{
unclosed_streams
}
"
if
StreamWrapper
.
session_streams
:
raise
pytest
.
UsageError
(
make_msg_and_close
(
"A previous test did not close the following streams:"
))
dataset
,
_
=
dataset_mock
.
load
(
config
)
consume
(
iter
(
dataset
))
if
StreamWrapper
.
session_streams
:
raise
AssertionError
(
make_msg_and_close
(
"The following streams were not closed after a full iteration:"
))
@
parametrize_dataset_mocks
(
DATASET_MOCKS
)
def
test_no_unaccompanied_pure_tensors
(
self
,
dataset_mock
,
config
):
dataset
,
_
=
dataset_mock
.
load
(
config
)
sample
=
next_consume
(
iter
(
dataset
))
pure_tensors
=
{
key
for
key
,
value
in
sample
.
items
()
if
is_pure_tensor
(
value
)}
if
pure_tensors
and
not
any
(
isinstance
(
item
,
(
tv_tensors
.
Image
,
tv_tensors
.
Video
,
EncodedImage
))
for
item
in
sample
.
values
()
):
raise
AssertionError
(
f
"The values of key(s) "
f
"
{
sequence_to_str
(
sorted
(
pure_tensors
),
separate_last
=
'and '
)
}
contained pure tensors, "
f
"but didn't find any (encoded) image or video."
)
@
parametrize_dataset_mocks
(
DATASET_MOCKS
)
def
test_transformable
(
self
,
dataset_mock
,
config
):
dataset
,
_
=
dataset_mock
.
load
(
config
)
dataset
=
dataset
.
map
(
transforms
.
Identity
())
consume
(
iter
(
dataset
))
@
parametrize_dataset_mocks
(
DATASET_MOCKS
)
def
test_traversable
(
self
,
dataset_mock
,
config
):
dataset
,
_
=
dataset_mock
.
load
(
config
)
traverse_dps
(
dataset
)
@
parametrize_dataset_mocks
(
DATASET_MOCKS
)
def
test_serializable
(
self
,
dataset_mock
,
config
):
dataset
,
_
=
dataset_mock
.
load
(
config
)
pickle
.
dumps
(
dataset
)
# This has to be a proper function, since lambda's or local functions
# cannot be pickled, but this is a requirement for the DataLoader with
# multiprocessing, i.e. num_workers > 0
def
_collate_fn
(
self
,
batch
):
return
batch
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
1
])
@
parametrize_dataset_mocks
(
DATASET_MOCKS
)
def
test_data_loader
(
self
,
dataset_mock
,
config
,
num_workers
):
dataset
,
_
=
dataset_mock
.
load
(
config
)
dl
=
DataLoader
(
dataset
,
batch_size
=
2
,
num_workers
=
num_workers
,
collate_fn
=
self
.
_collate_fn
,
)
consume
(
dl
)
# TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also
# that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680
# contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now.
@
parametrize_dataset_mocks
(
DATASET_MOCKS
)
@
pytest
.
mark
.
parametrize
(
"annotation_dp_type"
,
(
Shuffler
,
ShardingFilter
))
def
test_has_annotations
(
self
,
dataset_mock
,
config
,
annotation_dp_type
):
dataset
,
_
=
dataset_mock
.
load
(
config
)
if
not
any
(
isinstance
(
dp
,
annotation_dp_type
)
for
dp
in
extract_datapipes
(
dataset
)):
raise
AssertionError
(
f
"The dataset doesn't contain a
{
annotation_dp_type
.
__name__
}
() datapipe."
)
@
parametrize_dataset_mocks
(
DATASET_MOCKS
)
def
test_save_load
(
self
,
dataset_mock
,
config
):
dataset
,
_
=
dataset_mock
.
load
(
config
)
sample
=
next_consume
(
iter
(
dataset
))
with
io
.
BytesIO
()
as
buffer
:
torch
.
save
(
sample
,
buffer
)
buffer
.
seek
(
0
)
assert_samples_equal
(
torch
.
load
(
buffer
,
weights_only
=
True
),
sample
)
@
parametrize_dataset_mocks
(
DATASET_MOCKS
)
def
test_infinite_buffer_size
(
self
,
dataset_mock
,
config
):
dataset
,
_
=
dataset_mock
.
load
(
config
)
for
dp
in
extract_datapipes
(
dataset
):
if
hasattr
(
dp
,
"buffer_size"
):
# TODO: replace this with the proper sentinel as soon as https://github.com/pytorch/data/issues/335 is
# resolved
assert
dp
.
buffer_size
==
INFINITE_BUFFER_SIZE
@
parametrize_dataset_mocks
(
DATASET_MOCKS
)
def
test_has_length
(
self
,
dataset_mock
,
config
):
dataset
,
_
=
dataset_mock
.
load
(
config
)
assert
len
(
dataset
)
>
0
@
parametrize_dataset_mocks
(
DATASET_MOCKS
[
"qmnist"
])
class
TestQMNIST
:
def
test_extra_label
(
self
,
dataset_mock
,
config
):
dataset
,
_
=
dataset_mock
.
load
(
config
)
sample
=
next_consume
(
iter
(
dataset
))
for
key
,
type
in
(
(
"nist_hsf_series"
,
int
),
(
"nist_writer_id"
,
int
),
(
"digit_index"
,
int
),
(
"nist_label"
,
int
),
(
"global_digit_index"
,
int
),
(
"duplicate"
,
bool
),
(
"unused"
,
bool
),
):
assert
key
in
sample
and
isinstance
(
sample
[
key
],
type
)
@
parametrize_dataset_mocks
(
DATASET_MOCKS
[
"gtsrb"
])
class
TestGTSRB
:
def
test_label_matches_path
(
self
,
dataset_mock
,
config
):
# We read the labels from the csv files instead. But for the trainset, the labels are also part of the path.
# This test makes sure that they're both the same
if
config
[
"split"
]
!=
"train"
:
return
dataset
,
_
=
dataset_mock
.
load
(
config
)
for
sample
in
dataset
:
label_from_path
=
int
(
Path
(
sample
[
"path"
]).
parent
.
name
)
assert
sample
[
"label"
]
==
label_from_path
@
parametrize_dataset_mocks
(
DATASET_MOCKS
[
"usps"
])
class
TestUSPS
:
def
test_sample_content
(
self
,
dataset_mock
,
config
):
dataset
,
_
=
dataset_mock
.
load
(
config
)
for
sample
in
dataset
:
assert
"image"
in
sample
assert
"label"
in
sample
assert
isinstance
(
sample
[
"image"
],
tv_tensors
.
Image
)
assert
isinstance
(
sample
[
"label"
],
Label
)
assert
sample
[
"image"
].
shape
==
(
1
,
16
,
16
)
test/test_prototype_datasets_utils.py
deleted
100644 → 0
View file @
f44f20cf
import
gzip
import
pathlib
import
sys
import
numpy
as
np
import
pytest
import
torch
from
datasets_utils
import
make_fake_flo_file
,
make_tar
from
torchdata.datapipes.iter
import
FileOpener
,
TarArchiveLoader
from
torchvision.datasets._optical_flow
import
_read_flo
as
read_flo_ref
from
torchvision.datasets.utils
import
_decompress
from
torchvision.prototype.datasets.utils
import
Dataset
,
GDriveResource
,
HttpResource
,
OnlineResource
from
torchvision.prototype.datasets.utils._internal
import
fromfile
,
read_flo
@
pytest
.
mark
.
filterwarnings
(
"error:The given NumPy array is not writeable:UserWarning"
)
@
pytest
.
mark
.
parametrize
(
(
"np_dtype"
,
"torch_dtype"
,
"byte_order"
),
[
(
">f4"
,
torch
.
float32
,
"big"
),
(
"<f8"
,
torch
.
float64
,
"little"
),
(
"<i4"
,
torch
.
int32
,
"little"
),
(
">i8"
,
torch
.
int64
,
"big"
),
(
"|u1"
,
torch
.
uint8
,
sys
.
byteorder
),
],
)
@
pytest
.
mark
.
parametrize
(
"count"
,
(
-
1
,
2
))
@
pytest
.
mark
.
parametrize
(
"mode"
,
(
"rb"
,
"r+b"
))
def
test_fromfile
(
tmpdir
,
np_dtype
,
torch_dtype
,
byte_order
,
count
,
mode
):
path
=
tmpdir
/
"data.bin"
rng
=
np
.
random
.
RandomState
(
0
)
rng
.
randn
(
5
if
count
==
-
1
else
count
+
1
).
astype
(
np_dtype
).
tofile
(
path
)
for
count_
in
(
-
1
,
count
//
2
):
expected
=
torch
.
from_numpy
(
np
.
fromfile
(
path
,
dtype
=
np_dtype
,
count
=
count_
).
astype
(
np_dtype
[
1
:]))
with
open
(
path
,
mode
)
as
file
:
actual
=
fromfile
(
file
,
dtype
=
torch_dtype
,
byte_order
=
byte_order
,
count
=
count_
)
torch
.
testing
.
assert_close
(
actual
,
expected
)
def
test_read_flo
(
tmpdir
):
path
=
tmpdir
/
"test.flo"
make_fake_flo_file
(
3
,
4
,
path
)
with
open
(
path
,
"rb"
)
as
file
:
actual
=
read_flo
(
file
)
expected
=
torch
.
from_numpy
(
read_flo_ref
(
path
).
astype
(
"f4"
,
copy
=
False
))
torch
.
testing
.
assert_close
(
actual
,
expected
)
class
TestOnlineResource
:
class
DummyResource
(
OnlineResource
):
def
__init__
(
self
,
download_fn
=
None
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
_download_fn
=
download_fn
def
_download
(
self
,
root
):
if
self
.
_download_fn
is
None
:
raise
pytest
.
UsageError
(
"`_download()` was called, but `DummyResource(...)` was constructed without `download_fn`."
)
return
self
.
_download_fn
(
self
,
root
)
def
_make_file
(
self
,
root
,
*
,
content
,
name
=
"file.txt"
):
file
=
root
/
name
with
open
(
file
,
"w"
)
as
fh
:
fh
.
write
(
content
)
return
file
def
_make_folder
(
self
,
root
,
*
,
name
=
"folder"
):
folder
=
root
/
name
subfolder
=
folder
/
"subfolder"
subfolder
.
mkdir
(
parents
=
True
)
files
=
{}
for
idx
,
root
in
enumerate
([
folder
,
folder
,
subfolder
]):
content
=
f
"sentinel
{
idx
}
"
file
=
self
.
_make_file
(
root
,
name
=
f
"file
{
idx
}
.txt"
,
content
=
content
)
files
[
str
(
file
)]
=
content
return
folder
,
files
def
_make_tar
(
self
,
root
,
*
,
name
=
"archive.tar"
,
remove
=
True
):
folder
,
files
=
self
.
_make_folder
(
root
,
name
=
name
.
split
(
"."
)[
0
])
archive
=
make_tar
(
root
,
name
,
folder
,
remove
=
remove
)
files
=
{
str
(
archive
/
pathlib
.
Path
(
file
).
relative_to
(
root
)):
content
for
file
,
content
in
files
.
items
()}
return
archive
,
files
def
test_load_file
(
self
,
tmp_path
):
content
=
"sentinel"
file
=
self
.
_make_file
(
tmp_path
,
content
=
content
)
resource
=
self
.
DummyResource
(
file_name
=
file
.
name
)
dp
=
resource
.
load
(
tmp_path
)
assert
isinstance
(
dp
,
FileOpener
)
data
=
list
(
dp
)
assert
len
(
data
)
==
1
path
,
buffer
=
data
[
0
]
assert
path
==
str
(
file
)
assert
buffer
.
read
().
decode
()
==
content
def
test_load_folder
(
self
,
tmp_path
):
folder
,
files
=
self
.
_make_folder
(
tmp_path
)
resource
=
self
.
DummyResource
(
file_name
=
folder
.
name
)
dp
=
resource
.
load
(
tmp_path
)
assert
isinstance
(
dp
,
FileOpener
)
assert
{
path
:
buffer
.
read
().
decode
()
for
path
,
buffer
in
dp
}
==
files
def
test_load_archive
(
self
,
tmp_path
):
archive
,
files
=
self
.
_make_tar
(
tmp_path
)
resource
=
self
.
DummyResource
(
file_name
=
archive
.
name
)
dp
=
resource
.
load
(
tmp_path
)
assert
isinstance
(
dp
,
TarArchiveLoader
)
assert
{
path
:
buffer
.
read
().
decode
()
for
path
,
buffer
in
dp
}
==
files
def
test_priority_decompressed_gt_raw
(
self
,
tmp_path
):
# We don't need to actually compress here. Adding the suffix is sufficient
self
.
_make_file
(
tmp_path
,
content
=
"raw_sentinel"
,
name
=
"file.txt.gz"
)
file
=
self
.
_make_file
(
tmp_path
,
content
=
"decompressed_sentinel"
,
name
=
"file.txt"
)
resource
=
self
.
DummyResource
(
file_name
=
file
.
name
)
dp
=
resource
.
load
(
tmp_path
)
path
,
buffer
=
next
(
iter
(
dp
))
assert
path
==
str
(
file
)
assert
buffer
.
read
().
decode
()
==
"decompressed_sentinel"
def
test_priority_extracted_gt_decompressed
(
self
,
tmp_path
):
archive
,
_
=
self
.
_make_tar
(
tmp_path
,
remove
=
False
)
resource
=
self
.
DummyResource
(
file_name
=
archive
.
name
)
dp
=
resource
.
load
(
tmp_path
)
# If the archive had been selected, this would be a `TarArchiveReader`
assert
isinstance
(
dp
,
FileOpener
)
def
test_download
(
self
,
tmp_path
):
download_fn_was_called
=
False
def
download_fn
(
resource
,
root
):
nonlocal
download_fn_was_called
download_fn_was_called
=
True
return
self
.
_make_file
(
root
,
content
=
"_"
,
name
=
resource
.
file_name
)
resource
=
self
.
DummyResource
(
file_name
=
"file.txt"
,
download_fn
=
download_fn
,
)
resource
.
load
(
tmp_path
)
assert
download_fn_was_called
,
"`download_fn()` was never called"
# This tests the `"decompress"` literal as well as a custom callable
@
pytest
.
mark
.
parametrize
(
"preprocess"
,
[
"decompress"
,
lambda
path
:
_decompress
(
str
(
path
),
remove_finished
=
True
),
],
)
def
test_preprocess_decompress
(
self
,
tmp_path
,
preprocess
):
file_name
=
"file.txt.gz"
content
=
"sentinel"
def
download_fn
(
resource
,
root
):
file
=
root
/
resource
.
file_name
with
gzip
.
open
(
file
,
"wb"
)
as
fh
:
fh
.
write
(
content
.
encode
())
return
file
resource
=
self
.
DummyResource
(
file_name
=
file_name
,
preprocess
=
preprocess
,
download_fn
=
download_fn
)
dp
=
resource
.
load
(
tmp_path
)
data
=
list
(
dp
)
assert
len
(
data
)
==
1
path
,
buffer
=
data
[
0
]
assert
path
==
str
(
tmp_path
/
file_name
).
replace
(
".gz"
,
""
)
assert
buffer
.
read
().
decode
()
==
content
def
test_preprocess_extract
(
self
,
tmp_path
):
files
=
None
def
download_fn
(
resource
,
root
):
nonlocal
files
archive
,
files
=
self
.
_make_tar
(
root
,
name
=
resource
.
file_name
)
return
archive
resource
=
self
.
DummyResource
(
file_name
=
"folder.tar"
,
preprocess
=
"extract"
,
download_fn
=
download_fn
)
dp
=
resource
.
load
(
tmp_path
)
assert
files
is
not
None
,
"`download_fn()` was never called"
assert
isinstance
(
dp
,
FileOpener
)
actual
=
{
path
:
buffer
.
read
().
decode
()
for
path
,
buffer
in
dp
}
expected
=
{
path
.
replace
(
resource
.
file_name
,
resource
.
file_name
.
split
(
"."
)[
0
]):
content
for
path
,
content
in
files
.
items
()
}
assert
actual
==
expected
def
test_preprocess_only_after_download
(
self
,
tmp_path
):
file
=
self
.
_make_file
(
tmp_path
,
content
=
"_"
)
def
preprocess
(
path
):
raise
AssertionError
(
"`preprocess` was called although the file was already present."
)
resource
=
self
.
DummyResource
(
file_name
=
file
.
name
,
preprocess
=
preprocess
,
)
resource
.
load
(
tmp_path
)
class
TestHttpResource
:
def
test_resolve_to_http
(
self
,
mocker
):
file_name
=
"data.tar"
original_url
=
f
"http://downloads.pytorch.org/
{
file_name
}
"
redirected_url
=
original_url
.
replace
(
"http"
,
"https"
)
sha256_sentinel
=
"sha256_sentinel"
def
preprocess_sentinel
(
path
):
return
path
original_resource
=
HttpResource
(
original_url
,
sha256
=
sha256_sentinel
,
preprocess
=
preprocess_sentinel
,
)
mocker
.
patch
(
"torchvision.prototype.datasets.utils._resource._get_redirect_url"
,
return_value
=
redirected_url
)
redirected_resource
=
original_resource
.
resolve
()
assert
isinstance
(
redirected_resource
,
HttpResource
)
assert
redirected_resource
.
url
==
redirected_url
assert
redirected_resource
.
file_name
==
file_name
assert
redirected_resource
.
sha256
==
sha256_sentinel
assert
redirected_resource
.
_preprocess
is
preprocess_sentinel
def
test_resolve_to_gdrive
(
self
,
mocker
):
file_name
=
"data.tar"
original_url
=
f
"http://downloads.pytorch.org/
{
file_name
}
"
id_sentinel
=
"id-sentinel"
redirected_url
=
f
"https://drive.google.com/file/d/
{
id_sentinel
}
/view"
sha256_sentinel
=
"sha256_sentinel"
def
preprocess_sentinel
(
path
):
return
path
original_resource
=
HttpResource
(
original_url
,
sha256
=
sha256_sentinel
,
preprocess
=
preprocess_sentinel
,
)
mocker
.
patch
(
"torchvision.prototype.datasets.utils._resource._get_redirect_url"
,
return_value
=
redirected_url
)
redirected_resource
=
original_resource
.
resolve
()
assert
isinstance
(
redirected_resource
,
GDriveResource
)
assert
redirected_resource
.
id
==
id_sentinel
assert
redirected_resource
.
file_name
==
file_name
assert
redirected_resource
.
sha256
==
sha256_sentinel
assert
redirected_resource
.
_preprocess
is
preprocess_sentinel
def
test_missing_dependency_error
():
class
DummyDataset
(
Dataset
):
def
__init__
(
self
):
super
().
__init__
(
root
=
"root"
,
dependencies
=
(
"fake_dependency"
,))
def
_resources
(
self
):
pass
def
_datapipe
(
self
,
resource_dps
):
pass
def
__len__
(
self
):
pass
with
pytest
.
raises
(
ModuleNotFoundError
,
match
=
"depends on the third-party package 'fake_dependency'"
):
DummyDataset
()
test/test_prototype_models.py
deleted
100644 → 0
View file @
f44f20cf
import
pytest
import
test_models
as
TM
import
torch
from
common_utils
import
cpu_and_cuda
,
set_rng_seed
from
torchvision.prototype
import
models
@
pytest
.
mark
.
parametrize
(
"model_fn"
,
(
models
.
depth
.
stereo
.
raft_stereo_base
,))
@
pytest
.
mark
.
parametrize
(
"model_mode"
,
(
"standard"
,
"scripted"
))
@
pytest
.
mark
.
parametrize
(
"dev"
,
cpu_and_cuda
())
def
test_raft_stereo
(
model_fn
,
model_mode
,
dev
):
# A simple test to make sure the model can do forward pass and jit scriptable
set_rng_seed
(
0
)
# Use corr_pyramid and corr_block with smaller num_levels and radius to prevent nan output
# get the idea from test_models.test_raft
corr_pyramid
=
models
.
depth
.
stereo
.
raft_stereo
.
CorrPyramid1d
(
num_levels
=
2
)
corr_block
=
models
.
depth
.
stereo
.
raft_stereo
.
CorrBlock1d
(
num_levels
=
2
,
radius
=
2
)
model
=
model_fn
(
corr_pyramid
=
corr_pyramid
,
corr_block
=
corr_block
).
eval
().
to
(
dev
)
if
model_mode
==
"scripted"
:
model
=
torch
.
jit
.
script
(
model
)
img1
=
torch
.
rand
(
1
,
3
,
64
,
64
).
to
(
dev
)
img2
=
torch
.
rand
(
1
,
3
,
64
,
64
).
to
(
dev
)
num_iters
=
3
preds
=
model
(
img1
,
img2
,
num_iters
=
num_iters
)
depth_pred
=
preds
[
-
1
]
assert
len
(
preds
)
==
num_iters
,
"Number of predictions should be the same as model.num_iters"
assert
depth_pred
.
shape
==
torch
.
Size
(
[
1
,
1
,
64
,
64
]
),
f
"The output shape of depth_pred should be [1, 1, 64, 64] but instead it is
{
preds
[
0
].
shape
}
"
# Test against expected file output
TM
.
_assert_expected
(
depth_pred
,
name
=
model_fn
.
__name__
,
atol
=
1e-2
,
rtol
=
1e-2
)
@
pytest
.
mark
.
parametrize
(
"model_fn"
,
(
models
.
depth
.
stereo
.
crestereo_base
,))
@
pytest
.
mark
.
parametrize
(
"model_mode"
,
(
"standard"
,
"scripted"
))
@
pytest
.
mark
.
parametrize
(
"dev"
,
cpu_and_cuda
())
def
test_crestereo
(
model_fn
,
model_mode
,
dev
):
set_rng_seed
(
0
)
model
=
model_fn
().
eval
().
to
(
dev
)
if
model_mode
==
"scripted"
:
model
=
torch
.
jit
.
script
(
model
)
img1
=
torch
.
rand
(
1
,
3
,
64
,
64
).
to
(
dev
)
img2
=
torch
.
rand
(
1
,
3
,
64
,
64
).
to
(
dev
)
iterations
=
3
preds
=
model
(
img1
,
img2
,
flow_init
=
None
,
num_iters
=
iterations
)
disparity_pred
=
preds
[
-
1
]
# all the pyramid levels except the highest res make only half the number of iterations
expected_iterations
=
(
iterations
//
2
)
*
(
len
(
model
.
resolutions
)
-
1
)
expected_iterations
+=
iterations
assert
(
len
(
preds
)
==
expected_iterations
),
"Number of predictions should be the number of iterations multiplied by the number of pyramid levels"
assert
disparity_pred
.
shape
==
torch
.
Size
(
[
1
,
2
,
64
,
64
]
),
f
"Predicted disparity should have the same spatial shape as the input. Inputs shape
{
img1
.
shape
[
2
:]
}
, Prediction shape
{
disparity_pred
.
shape
[
2
:]
}
"
assert
all
(
d
.
shape
==
torch
.
Size
([
1
,
2
,
64
,
64
])
for
d
in
preds
),
"All predicted disparities are expected to have the same shape"
# test a backward pass with a dummy loss as well
preds
=
torch
.
stack
(
preds
,
dim
=
0
)
targets
=
torch
.
ones_like
(
preds
,
requires_grad
=
False
)
loss
=
torch
.
nn
.
functional
.
mse_loss
(
preds
,
targets
)
try
:
loss
.
backward
()
except
Exception
as
e
:
assert
False
,
f
"Backward pass failed with an unexpected exception:
{
e
.
__class__
.
__name__
}
{
e
}
"
TM
.
_assert_expected
(
disparity_pred
,
name
=
model_fn
.
__name__
,
atol
=
1e-2
,
rtol
=
1e-2
)
test/test_prototype_transforms.py
deleted
100644 → 0
View file @
f44f20cf
import
collections.abc
import
re
import
PIL.Image
import
pytest
import
torch
from
common_utils
import
assert_equal
,
make_bounding_boxes
,
make_detection_masks
,
make_image
,
make_video
from
torchvision.prototype
import
transforms
,
tv_tensors
from
torchvision.transforms.v2._utils
import
check_type
,
is_pure_tensor
from
torchvision.transforms.v2.functional
import
clamp_bounding_boxes
,
InterpolationMode
,
pil_to_tensor
,
to_pil_image
from
torchvision.tv_tensors
import
BoundingBoxes
,
BoundingBoxFormat
,
Image
,
Mask
,
Video
def
_parse_categories
(
categories
):
if
categories
is
None
:
num_categories
=
int
(
torch
.
randint
(
1
,
11
,
()))
elif
isinstance
(
categories
,
int
):
num_categories
=
categories
categories
=
[
f
"category
{
idx
}
"
for
idx
in
range
(
num_categories
)]
elif
isinstance
(
categories
,
collections
.
abc
.
Sequence
)
and
all
(
isinstance
(
category
,
str
)
for
category
in
categories
):
categories
=
list
(
categories
)
num_categories
=
len
(
categories
)
else
:
raise
pytest
.
UsageError
(
f
"`categories` can either be `None` (default), an integer, or a sequence of strings, "
f
"but got '
{
categories
}
' instead."
)
return
categories
,
num_categories
def
make_label
(
*
,
extra_dims
=
(),
categories
=
10
,
dtype
=
torch
.
int64
,
device
=
"cpu"
):
categories
,
num_categories
=
_parse_categories
(
categories
)
# The idiom `make_tensor(..., dtype=torch.int64).to(dtype)` is intentional to only get integer values,
# regardless of the requested dtype, e.g. 0 or 0.0 rather than 0 or 0.123
data
=
torch
.
testing
.
make_tensor
(
extra_dims
,
low
=
0
,
high
=
num_categories
,
dtype
=
torch
.
int64
,
device
=
device
).
to
(
dtype
)
return
tv_tensors
.
Label
(
data
,
categories
=
categories
)
class TestSimpleCopyPaste:
    """Unit tests for the private helpers of ``transforms.SimpleCopyPaste``."""

    def create_fake_image(self, mocker, image_type):
        # A real PIL image is cheap to construct (solid fill value 123);
        # every other image type is represented by a spec'd mock.
        if image_type == PIL.Image.Image:
            return PIL.Image.new("RGB", (32, 32), 123)
        return mocker.MagicMock(spec=image_type)

    def test__extract_image_targets_assertion(self, mocker):
        transform = transforms.SimpleCopyPaste()

        flat_sample = [
            # only ONE image, although the targets below form TWO groups
            self.create_fake_image(mocker, Image),
            # labels, bboxes, masks
            mocker.MagicMock(spec=tv_tensors.Label),
            mocker.MagicMock(spec=BoundingBoxes),
            mocker.MagicMock(spec=Mask),
            # bboxes, masks — second target group without a matching image
            mocker.MagicMock(spec=BoundingBoxes),
            mocker.MagicMock(spec=Mask),
        ]

        # Image count and target-group count disagree -> the sample must be rejected.
        with pytest.raises(TypeError, match="requires input sample to contain equal sized list of Images"):
            transform._extract_image_targets(flat_sample)

    @pytest.mark.parametrize("image_type", [Image, PIL.Image.Image, torch.Tensor])
    @pytest.mark.parametrize("label_type", [tv_tensors.Label, tv_tensors.OneHotLabel])
    def test__extract_image_targets(self, image_type, label_type, mocker):
        transform = transforms.SimpleCopyPaste()

        flat_sample = [
            # images, batch size = 2
            self.create_fake_image(mocker, image_type),
            self.create_fake_image(mocker, image_type),
            # labels, bboxes, masks
            mocker.MagicMock(spec=label_type),
            mocker.MagicMock(spec=BoundingBoxes),
            mocker.MagicMock(spec=Mask),
            # labels, bboxes, masks
            mocker.MagicMock(spec=label_type),
            mocker.MagicMock(spec=BoundingBoxes),
            mocker.MagicMock(spec=Mask),
        ]

        images, targets = transform._extract_image_targets(flat_sample)

        assert len(images) == len(targets) == 2
        if image_type == PIL.Image.Image:
            # PIL inputs are converted to tensors, so compare values, not identity.
            torch.testing.assert_close(images[0], pil_to_tensor(flat_sample[0]))
            torch.testing.assert_close(images[1], pil_to_tensor(flat_sample[1]))
        else:
            assert images[0] == flat_sample[0]
            assert images[1] == flat_sample[1]

        # Each extracted target dict must carry the original objects from the
        # flat sample under the expected keys and with the expected types.
        for target in targets:
            for key, type_ in [
                ("boxes", BoundingBoxes),
                ("masks", Mask),
                ("labels", label_type),
            ]:
                assert key in target
                assert isinstance(target[key], type_)
                assert target[key] in flat_sample

    @pytest.mark.parametrize("label_type", [tv_tensors.Label, tv_tensors.OneHotLabel])
    def test__copy_paste(self, label_type):
        # Source image: constant 2, with two rectangular instance masks whose
        # extents match the boxes below.
        image = 2 * torch.ones(3, 32, 32)
        masks = torch.zeros(2, 32, 32)
        masks[0, 3:9, 2:8] = 1
        masks[1, 20:30, 20:30] = 1
        labels = torch.tensor([1, 2])
        blending = True
        resize_interpolation = InterpolationMode.BILINEAR
        antialias = None
        if label_type == tv_tensors.OneHotLabel:
            labels = torch.nn.functional.one_hot(labels, num_classes=5)
        target = {
            "boxes": BoundingBoxes(
                torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", canvas_size=(32, 32)
            ),
            "masks": Mask(masks),
            "labels": label_type(labels),
        }

        # Paste image: constant 10, again with two rectangular instances.
        paste_image = 10 * torch.ones(3, 32, 32)
        paste_masks = torch.zeros(2, 32, 32)
        paste_masks[0, 13:19, 12:18] = 1
        paste_masks[1, 15:19, 1:8] = 1
        paste_labels = torch.tensor([3, 4])
        if label_type == tv_tensors.OneHotLabel:
            paste_labels = torch.nn.functional.one_hot(paste_labels, num_classes=5)
        paste_target = {
            "boxes": BoundingBoxes(
                torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", canvas_size=(32, 32)
            ),
            "masks": Mask(paste_masks),
            "labels": label_type(paste_labels),
        }

        transform = transforms.SimpleCopyPaste()
        # random_selection presumably picks both paste instances (indices 0 and
        # 1), which matches the 4-instance expectations below — TODO confirm.
        random_selection = torch.tensor([0, 1])
        output_image, output_target = transform._copy_paste(
            image, target, paste_image, paste_target, random_selection, blending, resize_interpolation, antialias
        )

        # The output should only contain the two constant fill values.
        assert output_image.unique().tolist() == [2, 10]
        # Original targets first, pasted targets appended after.
        assert output_target["boxes"].shape == (4, 4)
        torch.testing.assert_close(output_target["boxes"][:2, :], target["boxes"])
        torch.testing.assert_close(output_target["boxes"][2:, :], paste_target["boxes"])

        expected_labels = torch.tensor([1, 2, 3, 4])
        if label_type == tv_tensors.OneHotLabel:
            expected_labels = torch.nn.functional.one_hot(expected_labels, num_classes=5)
        torch.testing.assert_close(output_target["labels"], label_type(expected_labels))

        assert output_target["masks"].shape == (4, 32, 32)
        torch.testing.assert_close(output_target["masks"][:2, :], target["masks"])
        torch.testing.assert_close(output_target["masks"][2:, :], paste_target["masks"])
class TestFixedSizeCrop:
    """Unit tests for ``transforms.FixedSizeCrop`` internals."""

    def test__get_params(self, mocker):
        crop_size = (7, 7)
        batch_shape = (10,)
        # Canvas is taller (11 > 7) but narrower (5 < 7) than the crop, so the
        # assertions below expect both a crop and a pad to be scheduled.
        canvas_size = (11, 5)

        transform = transforms.FixedSizeCrop(size=crop_size)

        flat_inputs = [
            make_image(size=canvas_size, color_space="RGB"),
            make_bounding_boxes(format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, num_boxes=batch_shape[0]),
        ]
        params = transform._get_params(flat_inputs)

        assert params["needs_crop"]
        assert params["height"] <= crop_size[0]
        assert params["width"] <= crop_size[1]

        # One boolean validity flag per bounding box.
        assert (
            isinstance(params["is_valid"], torch.Tensor)
            and params["is_valid"].dtype is torch.bool
            and params["is_valid"].shape == batch_shape
        )

        assert params["needs_pad"]
        assert any(pad > 0 for pad in params["padding"])

    def test__transform_culling(self, mocker):
        batch_size = 10
        canvas_size = (10, 10)

        # Patch _get_params so that the crop is a no-op and only the random
        # is_valid mask decides which instances survive.
        is_valid = torch.randint(0, 2, (batch_size,), dtype=torch.bool)
        mocker.patch(
            "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
            return_value=dict(
                needs_crop=True,
                top=0,
                left=0,
                height=canvas_size[0],
                width=canvas_size[1],
                is_valid=is_valid,
                needs_pad=False,
            ),
        )

        bounding_boxes = make_bounding_boxes(
            format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, num_boxes=batch_size
        )
        masks = make_detection_masks(size=canvas_size, num_masks=batch_size)
        labels = make_label(extra_dims=(batch_size,))

        transform = transforms.FixedSizeCrop((-1, -1))
        mocker.patch(
            "torchvision.prototype.transforms._geometry.has_any",
            return_value=True,
        )

        output = transform(
            dict(
                bounding_boxes=bounding_boxes,
                masks=masks,
                labels=labels,
            )
        )

        # Invalid instances must be culled consistently across all target kinds.
        assert_equal(output["bounding_boxes"], bounding_boxes[is_valid])
        assert_equal(output["masks"], masks[is_valid])
        assert_equal(output["labels"], labels[is_valid])

    def test__transform_bounding_boxes_clamping(self, mocker):
        batch_size = 3
        canvas_size = (10, 10)

        # All instances valid and no padding: isolates the clamping call.
        mocker.patch(
            "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
            return_value=dict(
                needs_crop=True,
                top=0,
                left=0,
                height=canvas_size[0],
                width=canvas_size[1],
                is_valid=torch.full((batch_size,), fill_value=True),
                needs_pad=False,
            ),
        )

        bounding_boxes = make_bounding_boxes(
            format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, num_boxes=batch_size
        )
        # `wraps` keeps the real behavior while recording the call.
        mock = mocker.patch(
            "torchvision.prototype.transforms._geometry.F.clamp_bounding_boxes",
            wraps=clamp_bounding_boxes,
        )

        transform = transforms.FixedSizeCrop((-1, -1))
        mocker.patch(
            "torchvision.prototype.transforms._geometry.has_any",
            return_value=True,
        )

        transform(bounding_boxes)

        mock.assert_called_once()
class TestLabelToOneHot:
    """Checks ``transforms.LabelToOneHot`` output type, shape, and category metadata."""

    def test__transform(self):
        categories = ["apple", "pear", "pineapple"]
        labels = tv_tensors.Label(torch.tensor([0, 1, 2, 1]), categories=categories)

        one_hot = transforms.LabelToOneHot()(labels)

        # Four input labels over three categories -> a (4, 3) one-hot matrix.
        assert isinstance(one_hot, tv_tensors.OneHotLabel)
        assert one_hot.shape == (4, 3)
        # The category names must survive the conversion untouched.
        assert one_hot.categories == labels.categories == categories
class TestPermuteDimensions:
    """Unit tests for ``transforms.PermuteDimensions``."""

    @pytest.mark.parametrize(
        ("dims", "inverse_dims"),
        [
            # (2, 1, 0) is its own inverse for 3-dim images.
            (
                {Image: (2, 1, 0), Video: None},
                {Image: (2, 1, 0), Video: None},
            ),
            # For the 4-dim video permutation, inverse_dims undoes dims.
            (
                {Image: (2, 1, 0), Video: (1, 2, 3, 0)},
                {Image: (2, 1, 0), Video: (3, 0, 1, 2)},
            ),
        ],
    )
    def test_call(self, dims, inverse_dims):
        # Mixed sample: tensor-like entries plus plain str/int passthrough values.
        sample = dict(
            image=make_image(),
            bounding_boxes=make_bounding_boxes(format=BoundingBoxFormat.XYXY),
            video=make_video(),
            str="str",
            int=0,
        )

        transform = transforms.PermuteDimensions(dims)
        transformed_sample = transform(sample)

        for key, value in sample.items():
            value_type = type(value)
            transformed_value = transformed_sample[key]

            if check_type(value, (Image, is_pure_tensor, Video)):
                if transform.dims.get(value_type) is not None:
                    # Applying the inverse permutation must recover the input.
                    assert transformed_value.permute(inverse_dims[value_type]).equal(value)
                # Permuted outputs are plain tensors, not tv_tensor subclasses.
                assert type(transformed_value) == torch.Tensor
            else:
                # Non-image values pass through untouched (same object).
                assert transformed_value is value

    @pytest.mark.filterwarnings("error")
    def test_plain_tensor_call(self):
        # A tuple `dims` applies to plain tensors without warning.
        tensor = torch.empty((2, 3, 4))
        transform = transforms.PermuteDimensions(dims=(1, 2, 0))

        assert transform(tensor).shape == (3, 4, 2)

    @pytest.mark.parametrize("other_type", [Image, Video])
    def test_plain_tensor_warning(self, other_type):
        # Passing torch.Tensor explicitly in a dims dict must warn that plain
        # tensors are not transformed.
        with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
            transforms.PermuteDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)})
class TestTransposeDimensions:
    """Unit tests for ``transforms.TransposeDimensions``."""

    @pytest.mark.parametrize(
        "dims",
        [
            # Either a single dim pair for everything, or a per-type mapping.
            (-1, -2),
            {Image: (1, 2), Video: None},
        ],
    )
    def test_call(self, dims):
        # Mixed sample: tensor-like entries plus plain str/int passthrough values.
        sample = dict(
            image=make_image(),
            bounding_boxes=make_bounding_boxes(format=BoundingBoxFormat.XYXY),
            video=make_video(),
            str="str",
            int=0,
        )

        transform = transforms.TransposeDimensions(dims)
        transformed_sample = transform(sample)

        for key, value in sample.items():
            value_type = type(value)
            transformed_value = transformed_sample[key]

            transposed_dims = transform.dims.get(value_type)
            if check_type(value, (Image, is_pure_tensor, Video)):
                if transposed_dims is not None:
                    # transpose is its own inverse, so re-applying it must
                    # recover the input.
                    assert transformed_value.transpose(*transposed_dims).equal(value)
                # Transposed outputs are plain tensors, not tv_tensor subclasses.
                assert type(transformed_value) == torch.Tensor
            else:
                # Non-image values pass through untouched (same object).
                assert transformed_value is value

    @pytest.mark.filterwarnings("error")
    def test_plain_tensor_call(self):
        # A tuple `dims` applies to plain tensors without warning.
        tensor = torch.empty((2, 3, 4))
        transform = transforms.TransposeDimensions(dims=(0, 2))

        assert transform(tensor).shape == (4, 3, 2)

    @pytest.mark.parametrize("other_type", [Image, Video])
    def test_plain_tensor_warning(self, other_type):
        # Passing torch.Tensor explicitly in a dims dict must warn that plain
        # tensors are not transformed.
        with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
            transforms.TransposeDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)})
import
importlib.machinery
import
importlib.util
from
pathlib
import
Path
def import_transforms_from_references(reference):
    """Load ``references/<reference>/transforms.py`` from the repository root as a module.

    Returns a fresh module object on every call; the source file is executed
    under the module name "transforms".
    """
    project_root = Path(__file__).parent.parent
    source_path = project_root / "references" / reference / "transforms.py"

    loader = importlib.machinery.SourceFileLoader("transforms", str(source_path))
    spec = importlib.util.spec_from_loader("transforms", loader)
    module = importlib.util.module_from_spec(spec)
    loader.exec_module(module)
    return module
# Reference transforms from the detection training references, loaded once at
# module import; used as the ground truth in the comparison test below.
det_transforms = import_transforms_from_references("detection")
def test_fixed_sized_crop_against_detection_reference():
    """Compare prototype ``FixedSizeCrop`` against the detection-references implementation."""

    def make_tv_tensors():
        # Yields the same (image, target) content three times, varying only the
        # image container: PIL image, plain tensor, and tv_tensors.Image.
        size = (600, 800)
        num_objects = 22

        pil_image = to_pil_image(make_image(size=size, color_space="RGB"))
        target = {
            "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", num_boxes=num_objects, dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
            "masks": make_detection_masks(size=size, num_masks=num_objects, dtype=torch.long),
        }

        yield (pil_image, target)

        # NOTE(review): torch.Tensor(...) presumably strips the Image subclass
        # down to a plain tensor — confirm this is the intended conversion.
        tensor_image = torch.Tensor(make_image(size=size, color_space="RGB"))
        target = {
            "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", num_boxes=num_objects, dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
            "masks": make_detection_masks(size=size, num_masks=num_objects, dtype=torch.long),
        }

        yield (tensor_image, target)

        tv_tensor_image = make_image(size=size, color_space="RGB")
        target = {
            "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", num_boxes=num_objects, dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
            "masks": make_detection_masks(size=size, num_masks=num_objects, dtype=torch.long),
        }

        yield (tv_tensor_image, target)

    t = transforms.FixedSizeCrop((1024, 1024), fill=0)
    t_ref = det_transforms.FixedSizeCrop((1024, 1024), fill=0)

    for dp in make_tv_tensors():
        # We should use prototype transform first as reference transform performs inplace target update
        # Identical manual seeds make the random crop parameters comparable.
        torch.manual_seed(12)
        output = t(dp)

        torch.manual_seed(12)
        expected_output = t_ref(*dp)

        assert_equal(expected_output, output)
torchvision/prototype/__init__.py
deleted
100644 → 0
View file @
f44f20cf
from
.
import
models
,
transforms
,
tv_tensors
,
utils
torchvision/prototype/datasets/README.md
deleted
100644 → 0
View file @
f44f20cf
# Status of prototype datasets
These prototype datasets are based on
[
torchdata
](
https://github.com/pytorch/data
)
's datapipes. Torchdata
development
[
is
paused
](
https://github.com/pytorch/data/#torchdata-see-note-below-on-current-status
)
as of July 2023, so we are not actively maintaining this module. There is no
estimated date for a stable release of these datasets.
torchvision/prototype/datasets/__init__.py
deleted
100644 → 0
View file @
f44f20cf
# Fail fast with actionable install instructions if torchdata is missing,
# since all prototype datasets below are built on its datapipes.
try:
    import torchdata
except ModuleNotFoundError:
    # `from None` drops the original traceback; the bare ModuleNotFoundError
    # adds nothing over this message.
    raise ModuleNotFoundError(
        "`torchvision.prototype.datasets` depends on PyTorch's `torchdata` (https://github.com/pytorch/data). "
        "You can install it with "
        "`pip install --pre torchdata --extra-index-url https://download.pytorch.org/whl/nightly/cpu`"
    ) from None
from
.
import
utils
from
._home
import
home
# Load this last, since some parts depend on the above being loaded first
from
._api
import
list_datasets
,
info
,
load
,
register_info
,
register_dataset
# usort: skip
from
._folder
import
from_data_folder
,
from_image_folder
from
._builtin
import
*
torchvision/prototype/datasets/_api.py
deleted
100644 → 0
View file @
f44f20cf
import
pathlib
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Type
,
TypeVar
,
Union
from
torchvision.prototype.datasets
import
home
from
torchvision.prototype.datasets.utils
import
Dataset
from
torchvision.prototype.utils._internal
import
add_suggestion
# Generic value type used by the `find` helper below.
T = TypeVar("T")
# Type of the classes accepted by `register_dataset`; constrained to `Dataset` subclasses.
D = TypeVar("D", bound=Type[Dataset])
# Registry mapping dataset names to their static info dictionaries.
BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {}


def register_info(name: str) -> Callable[[Callable[[], Dict[str, Any]]], Callable[[], Dict[str, Any]]]:
    """Decorator factory that registers a dataset-info function under ``name``.

    The decorated function is invoked once at registration time and its result
    is stored in ``BUILTIN_INFOS``; the function itself is returned unchanged.
    """

    def decorator(info_fn: Callable[[], Dict[str, Any]]) -> Callable[[], Dict[str, Any]]:
        BUILTIN_INFOS[name] = info_fn()
        return info_fn

    return decorator
# Registry mapping dataset names to their dataset classes.
BUILTIN_DATASETS = {}


def register_dataset(name: str) -> Callable[[D], D]:
    """Decorator factory that records a dataset class in ``BUILTIN_DATASETS``
    under ``name`` and returns the class unchanged."""

    def decorator(dataset_cls: D) -> D:
        BUILTIN_DATASETS[name] = dataset_cls
        return dataset_cls

    return decorator
def list_datasets() -> List[str]:
    """Return the names of all registered built-in datasets, sorted alphabetically."""
    # Iterating the registry dict yields its keys.
    return sorted(BUILTIN_DATASETS)
def find(dct: Dict[str, T], name: str) -> T:
    """Case-insensitive lookup of ``name`` in ``dct``.

    Raises:
        ValueError: if ``name`` is not a key, with a close-match suggestion.
    """
    key = name.lower()
    try:
        return dct[key]
    except KeyError as error:
        message = add_suggestion(
            f"Unknown dataset '{key}'.",
            word=key,
            possibilities=dct.keys(),
            alternative_hint=lambda _: (
                "You can use torchvision.datasets.list_datasets() to get a list of all available datasets."
            ),
        )
        raise ValueError(message) from error
def info(name: str) -> Dict[str, Any]:
    """Return the static information dictionary registered for dataset ``name``
    (case-insensitive lookup via ``find``)."""
    return find(BUILTIN_INFOS, name)
def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **config: Any) -> Dataset:
    """Instantiate the built-in dataset registered under ``name``.

    Args:
        name: registered dataset name (looked up case-insensitively).
        root: directory holding the dataset files; defaults to ``home()/<name>``.
        **config: forwarded verbatim to the dataset constructor.
    """
    dataset_cls = find(BUILTIN_DATASETS, name)
    effective_root = pathlib.Path(home()) / name if root is None else root
    return dataset_cls(effective_root, **config)
torchvision/prototype/datasets/_builtin/README.md
deleted
100644 → 0
View file @
f44f20cf
# How to add new built-in prototype datasets
As the name implies, the datasets are still in a prototype state and thus subject to rapid change. This in turn means
that this document will also change a lot.
If you hit a blocker while adding a dataset, please have a look at another similar dataset to see how it is implemented
there. If you can't resolve it yourself, feel free to send a draft PR in order for us to help you out.
Finally,
`from torchvision.prototype import datasets`
is implied below.
## Implementation
Before we start with the actual implementation, you should create a module in
`torchvision/prototype/datasets/_builtin`
that hints at the dataset you are going to add. For example
`caltech.py`
for
`caltech101`
and
`caltech256`
. In that
module create a class that inherits from
`datasets.utils.Dataset`
and overwrites four methods that will be discussed in
detail below:
```
python
import
pathlib
from
typing
import
Any
,
BinaryIO
,
Dict
,
List
,
Tuple
,
Union
from
torchdata.datapipes.iter
import
IterDataPipe
from
torchvision.prototype.datasets.utils
import
Dataset
,
OnlineResource
from
.._api
import
register_dataset
,
register_info
NAME
=
"my-dataset"
@
register_info
(
NAME
)
def
_info
()
->
Dict
[
str
,
Any
]:
return
dict
(
...
)
@
register_dataset
(
NAME
)
class
MyDataset
(
Dataset
):
def
__init__
(
self
,
root
:
Union
[
str
,
pathlib
.
Path
],
*
,
...,
skip_integrity_check
:
bool
=
False
)
->
None
:
...
super
().
__init__
(
root
,
skip_integrity_check
=
skip_integrity_check
)
def
_resources
(
self
)
->
List
[
OnlineResource
]:
...
def
_datapipe
(
self
,
resource_dps
:
List
[
IterDataPipe
[
Tuple
[
str
,
BinaryIO
]]])
->
IterDataPipe
[
Dict
[
str
,
Any
]]:
...
def
__len__
(
self
)
->
int
:
...
```
In addition to the dataset, you also need to implement an
`_info()`
function that takes no arguments and returns a
dictionary of static information. The most common use case is to provide human-readable categories.
[
See below
](
#how-do-i-handle-a-dataset-that-defines-many-categories
)
how to handle cases with many categories.
Finally, both the dataset class and the info function need to be registered on the API with the respective decorators.
With that they are loadable through
`datasets.load("my-dataset")`
and
`datasets.info("my-dataset")`
, respectively.
### `__init__(self, root, *, ..., skip_integrity_check = False)`
Constructor of the dataset that will be called when the dataset is instantiated. In addition to the parameters of the
base class, it can take arbitrary keyword-only parameters with defaults. The checking of these parameters as well as
setting them as instance attributes has to happen before the call of
`super().__init__(...)`
, because that will invoke
the other methods, which possibly depend on the parameters. All instance attributes must be private, i.e. prefixed with
an underscore.
If the implementation of the dataset depends on third-party packages, pass them as a collection of strings to the base
class constructor, e.g.
`super().__init__(..., dependencies=("scipy",))`
. Their availability will be automatically
checked if a user tries to load the dataset. Within the implementation of the dataset, import these packages lazily to
avoid missing dependencies at import time.
### `_resources(self)`
Returns
`List[datasets.utils.OnlineResource]`
of all the files that need to be present locally before the dataset can be
built. The download will happen automatically.
Currently, the following
`OnlineResource`
's are supported:
-
`HttpResource`
: Used for files that are directly exposed through HTTP(s) and only requires the URL.
-
`GDriveResource`
: Used for files that are hosted on GDrive and requires the GDrive ID as well as the
`file_name`
.
-
`ManualDownloadResource`
: Used for files that are not publicly accessible and requires instructions on how to download them
manually. If the file does not exist, an error will be raised with the supplied instructions.
-
`KaggleDownloadResource`
: Used for files that are available on Kaggle. This inherits from
`ManualDownloadResource`
.
Although optional in general, all resources used in the built-in datasets should include a
[
SHA256
](
https://en.wikipedia.org/wiki/SHA-2
)
checksum for security. It will be automatically checked after the
download. You can compute the checksum with system utilities, e.g. `sha256sum`, or this snippet:
```
python
import
hashlib
def
sha256sum
(
path
,
chunk_size
=
1024
*
1024
):
checksum
=
hashlib
.
sha256
()
with
open
(
path
,
"rb"
)
as
f
:
while
chunk
:
=
f
.
read
(
chunk_size
):
checksum
.
update
(
chunk
)
print
(
checksum
.
hexdigest
())
```
### `_datapipe(self, resource_dps)`
This method is the heart of the dataset, where we transform the raw data into a usable form. A major difference compared
to the current stable datasets is that everything is performed through
`IterDataPipe`
's. From the perspective of someone
that is working with them rather than on them,
`IterDataPipe`
's behave just as generators, i.e. you can't do anything
with them besides iterating.
Of course, there are some common building blocks that should suffice in 95% of the cases. The most used are:
-
`Mapper`
: Apply a callable to every item in the datapipe.
-
`Filter`
: Keep only items that satisfy a condition.
-
`Demultiplexer`
: Split a datapipe into multiple ones.
-
`IterKeyZipper`
: Merge two datapipes into one.
All of them can be imported
`from torchdata.datapipes.iter`
. In addition, use
`functools.partial`
in case a callable
needs extra arguments. If the provided
`IterDataPipe`
's are not sufficient for the use case, it is also not complicated
to add one. See the MNIST or CelebA datasets for example.
`_datapipe()`
receives
`resource_dps`
, which is a list of datapipes that has a 1-to-1 correspondence with the return
value of
`_resources()`
. In case of archives with regular suffixes (
`.tar`
,
`.zip`
, ...), the datapipe will contain
tuples comprised of the path and the handle for every file in the archive. Otherwise, the datapipe will only contain one
of such tuples for the file specified by the resource.
Since the datapipes are iterable in nature, some datapipes feature an in-memory buffer, e.g.
`IterKeyZipper`
and
`Grouper`
. There are two issues with that:
1.
If not used carefully, this can easily overflow the host memory, since most datasets will not fit in completely.
2.
This can lead to unnecessarily long warm-up times when data is buffered that is only needed at runtime.
Thus, all buffered datapipes should be used as early as possible, e.g. zipping two datapipes of file handles rather than
trying to zip already loaded images.
There are two special datapipes that are not used through their class, but through the functions
`hint_shuffling`
and
`hint_sharding`
. As the name implies they only hint at a location in the datapipe graph where shuffling and sharding
should take place, but are no-ops by default. They can be imported from
`torchvision.prototype.datasets.utils._internal`
and are required in each dataset.
`hint_shuffling`
has to be placed before
`hint_sharding`
.
Finally, each item in the final datapipe should be a dictionary with
`str`
keys. There is no standardization of the
names (yet!).
### `__len__`
This returns an integer denoting the number of samples that can be drawn from the dataset. Please use
[
underscores
](
https://peps.python.org/pep-0515/
)
after every three digits starting from the right to enhance the
readability. For example,
`1_281_167`
vs.
`1281167`
.
If there are only two different numbers, a simple
`if`
/
`else`
is fine:
```
py
def
__len__
(
self
):
return
12_345
if
self
.
_split
==
"train"
else
6_789
```
If there are more options, using a dictionary usually is the most readable option:
```
py
def
__len__
(
self
):
return
{
"train"
:
3
,
"val"
:
2
,
"test"
:
1
,
}[
self
.
_split
]
```
If the number of samples depends on more than one parameter, you can use tuples as dictionary keys:
```
py
def
__len__
(
self
):
return
{
(
"train"
,
"bar"
):
4
,
(
"train"
,
"baz"
):
3
,
(
"test"
,
"bar"
):
2
,
(
"test"
,
"baz"
):
1
,
}[(
self
.
_split
,
self
.
_foo
)]
```
The length of the datapipe is only an annotation for subsequent processing of the datapipe and not needed during the
development process. Since it is an
`@abstractmethod`
you still have to implement it from the start. The canonical way
is to define a dummy method like
```
py
def
__len__
(
self
):
return
1
```
and only fill it with the correct data if the implementation is otherwise finished.
[
See below
](
#how-do-i-compute-the-number-of-samples
)
for a possible way to compute the number of samples.
## Tests
To test the dataset implementation, you usually don't need to add any tests, but need to provide a mock-up of the data.
This mock-up should resemble the original data as close as necessary, while containing only few examples.
To do this, add a new function in
[
`test/builtin_dataset_mocks.py`
](
../../../../test/builtin_dataset_mocks.py
)
with the
same name as you have used in
`@register_info`
and
`@register_dataset`
. This function is called "mock data function".
Decorate it with
`@register_mock(configs=[dict(...), ...])`
. Each dictionary denotes one configuration that the dataset
will be loaded with, e.g.
`datasets.load("my-dataset", **config)`
. For the most common case of a product of all options,
you can use the
`combinations_grid()`
helper function, e.g.
`configs=combinations_grid(split=("train", "test"), foo=("bar", "baz"))`
.
In case the name of the dataset includes hyphens
`-`
, replace them with underscores
`_`
in the function name and pass
the
`name`
parameter to
`@register_mock`
```
py
# this is defined in torchvision/prototype/datasets/_builtin
@
register_dataset
(
"my-dataset"
)
class
MyDataset
(
Dataset
):
...
@
register_mock
(
name
=
"my-dataset"
,
configs
=
...)
def
my_dataset
(
root
,
config
):
...
```
The mock data function receives two arguments:
-
`root`
: A
[
`pathlib.Path`
](
https://docs.python.org/3/library/pathlib.html#pathlib.Path
)
of a folder, in which the data
needs to be placed.
-
`config`
: The configuration to generate the data for. This is one of the dictionaries defined in
`@register_mock(configs=...)`
The function should generate all files that are needed for the current
`config`
. Each file should be complete, e.g. if
the dataset only has a single archive that contains multiple splits, you need to generate the full archive regardless of
the current
`config`
. Although this seems odd at first, this is important. Consider the following original data setup:
```
root
├── test
│ ├── test_image0.jpg
│ ...
└── train
├── train_image0.jpg
...
```
For map-style datasets (like the one currently in
`torchvision.datasets`
), one explicitly selects the files they want to
load. For example, something like
`(root / split).iterdir()`
works fine even if only the specific split folder is
present. With iterable-style datasets though, we get something like
`root.iterdir()`
from
`resource_dps`
in
`_datapipe()`
and need to manually
`Filter`
it to only keep the files we want. If we would only generate the data for
the current
`config`
, the test would also pass if the dataset is missing the filtering, but would fail on the real data.
For datasets that are ported from the old API, we already have some mock data in
[
`test/test_datasets.py`
](
../../../../test/test_datasets.py
)
. You can find the test case corresponding test case there
and have a look at the
`inject_fake_data`
function. There are a few differences though:
-
`tmp_dir`
corresponds to
`root`
, but is a
`str`
rather than a
[
`pathlib.Path`
](
https://docs.python.org/3/library/pathlib.html#pathlib.Path
)
. Thus, you often see something like
`folder = pathlib.Path(tmp_dir)`
. This is not needed.
-
The data generated by
`inject_fake_data`
was supposed to be in an extracted state. This is no longer the case for the
new mock-ups. Thus, you need to use helper functions like
`make_zip`
or
`make_tar`
to actually generate the files
specified in the dataset.
-
As explained in the paragraph above, the generated data is often "incomplete" and only valid for the given config.
Make sure you follow the instructions above.
The function should return an integer indicating the number of samples in the dataset for the current
`config`
.
Preferably, this number should be different for different
`config`
's to have more confidence in the dataset
implementation.
Finally, you can run the tests with
`pytest test/test_prototype_datasets_builtin.py -k {name}`
.
## FAQ
### How do I start?
Get the skeleton of your dataset class ready with all 4 methods. For
`_datapipe()`
, you can just do
`return resources_dp[0]`
to get started. Then import the dataset class in
`torchvision/prototype/datasets/_builtin/__init__.py`
: this will automatically register the dataset, and it will be
instantiable via
`datasets.load("mydataset")`
. On a separate script, try something like
```
py
from
torchvision.prototype
import
datasets
dataset
=
datasets
.
load
(
"mydataset"
)
for
sample
in
dataset
:
print
(
sample
)
# this is the content of an item in datapipe returned by _datapipe()
break
# Or you can also inspect the sample in a debugger
```
This will give you an idea of what the first datapipe in
`resources_dp`
contains. You can also do that with
`resources_dp[1]`
or
`resources_dp[2]`
(etc.) if they exist. Then follow the instructions above to manipulate these
datapipes and return the appropriate dictionary format.
### How do I handle a dataset that defines many categories?
As a rule of thumb,
`categories`
in the info dictionary should only be set manually for ten categories or fewer. If more
categories are needed, you can add a
`$NAME.categories`
file to the
`_builtin`
folder in which each line specifies a
category. To load such a file, use the
`from torchvision.prototype.datasets.utils._internal import read_categories_file`
function and pass it
`$NAME`
.
In case the categories can be generated from the dataset files, e.g. the dataset follows an image folder approach where
each folder denotes the name of the category, the dataset can overwrite the
`_generate_categories`
method. The method
should return a sequence of strings representing the category names. In the method body, you'll have to manually load
the resources, e.g.
```
py
resources
=
self
.
_resources
()
dp
=
resources
[
0
].
load
(
self
.
_root
)
```
Note that it is not necessary here to keep a datapipe until the final step. Stick with datapipes as long as it makes
sense and afterwards materialize the data with
`next(iter(dp))`
or
`list(dp)`
and proceed with that.
To generate the
`$NAME.categories`
file, run
`python -m torchvision.prototype.datasets.generate_category_files $NAME`
.
### What if a resource file forms an I/O bottleneck?
In general, we are ok with small performance hits of iterating archives rather than their extracted content. However, if
the performance hit becomes significant, the archives can still be preprocessed.
`OnlineResource`
accepts the
`preprocess`
parameter that can be a
`Callable[[pathlib.Path], pathlib.Path]`
where the input points to the file to be
preprocessed and the return value should be the result of the preprocessing to load. For convenience,
`preprocess`
also
accepts
`"decompress"`
and
`"extract"`
to handle these common scenarios.
### How do I compute the number of samples?
Unless the authors of the dataset published the exact numbers (even in this case we should check), there is no other way
than to iterate over the dataset and count the number of samples:
```
py
import
itertools
from
torchvision.prototype
import
datasets
def
combinations_grid
(
**
kwargs
):
return
[
dict
(
zip
(
kwargs
.
keys
(),
values
))
for
values
in
itertools
.
product
(
*
kwargs
.
values
())]
# If you have implemented the mock data function for the dataset tests, you can simply copy-paste from there
configs
=
combinations_grid
(
split
=
(
"train"
,
"test"
),
foo
=
(
"bar"
,
"baz"
))
for
config
in
configs
:
dataset
=
datasets
.
load
(
"my-dataset"
,
**
config
)
num_samples
=
0
for
_
in
dataset
:
num_samples
+=
1
print
(
", "
.
join
(
f
"
{
key
}
=
{
value
}
"
for
key
,
value
in
config
.
items
()),
num_samples
)
```
To speed this up, it is useful to temporarily comment out all unnecessary I/O, such as loading of images or annotation
files.
torchvision/prototype/datasets/_builtin/__init__.py
deleted
100644 → 0
View file @
f44f20cf
from
.caltech
import
Caltech101
,
Caltech256
from
.celeba
import
CelebA
from
.cifar
import
Cifar10
,
Cifar100
from
.clevr
import
CLEVR
from
.coco
import
Coco
from
.country211
import
Country211
from
.cub200
import
CUB200
from
.dtd
import
DTD
from
.eurosat
import
EuroSAT
from
.fer2013
import
FER2013
from
.food101
import
Food101
from
.gtsrb
import
GTSRB
from
.imagenet
import
ImageNet
from
.mnist
import
EMNIST
,
FashionMNIST
,
KMNIST
,
MNIST
,
QMNIST
from
.oxford_iiit_pet
import
OxfordIIITPet
from
.pcam
import
PCAM
from
.sbd
import
SBD
from
.semeion
import
SEMEION
from
.stanford_cars
import
StanfordCars
from
.svhn
import
SVHN
from
.usps
import
USPS
from
.voc
import
VOC
torchvision/prototype/datasets/_builtin/caltech.py
deleted
100644 → 0
View file @
f44f20cf
import
pathlib
import
re
from
typing
import
Any
,
BinaryIO
,
Dict
,
List
,
Tuple
,
Union
import
numpy
as
np
import
torch
from
torchdata.datapipes.iter
import
Filter
,
IterDataPipe
,
IterKeyZipper
,
Mapper
from
torchvision.prototype.datasets.utils
import
Dataset
,
EncodedImage
,
GDriveResource
,
OnlineResource
from
torchvision.prototype.datasets.utils._internal
import
(
hint_sharding
,
hint_shuffling
,
INFINITE_BUFFER_SIZE
,
read_categories_file
,
read_mat
,
)
from
torchvision.prototype.tv_tensors
import
Label
from
torchvision.tv_tensors
import
BoundingBoxes
from
.._api
import
register_dataset
,
register_info
@register_info("caltech101")
def _caltech101_info() -> Dict[str, Any]:
    """Static info for the ``caltech101`` dataset: its category names, read from disk."""
    return {"categories": read_categories_file("caltech101")}
@register_dataset("caltech101")
class Caltech101(Dataset):
    """Caltech 101 dataset (images zipped with their .mat annotations).

    - **homepage**: https://data.caltech.edu/records/20086
    - **dependencies**:
        - <scipy `https://scipy.org/`>_
    """

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        skip_integrity_check: bool = False,
    ) -> None:
        self._categories = _caltech101_info()["categories"]

        super().__init__(
            root,
            # scipy is needed by read_mat() to parse the .mat annotation files
            dependencies=("scipy",),
            skip_integrity_check=skip_integrity_check,
        )

    def _resources(self) -> List[OnlineResource]:
        """Two archives hosted on Google Drive: the images and the annotations."""
        images = GDriveResource(
            "137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp",
            file_name="101_ObjectCategories.tar.gz",
            sha256="af6ece2f339791ca20f855943d8b55dd60892c0a25105fcd631ee3d6430f9926",
            preprocess="decompress",
        )
        anns = GDriveResource(
            "175kQy3UsZ0wUEHZjqkUDdNVssr7bgh_m",
            file_name="Annotations.tar",
            sha256="1717f4e10aa837b05956e3f4c94456527b143eec0d95e935028b30aff40663d8",
        )
        return [images, anns]

    _IMAGES_NAME_PATTERN = re.compile(r"image_(?P<id>\d+)[.]jpg")
    _ANNS_NAME_PATTERN = re.compile(r"annotation_(?P<id>\d+)[.]mat")
    # A few annotation folders are named differently from the corresponding
    # image folders; map them back so images and annotations can be zipped.
    _ANNS_CATEGORY_MAP = {
        "Faces_2": "Faces",
        "Faces_3": "Faces_easy",
        "Motorbikes_16": "Motorbikes",
        "Airplanes_Side_2": "airplanes",
    }

    def _is_not_background_image(self, data: Tuple[str, Any]) -> bool:
        """Drop the clutter class shipped with the archive."""
        path = pathlib.Path(data[0])
        return path.parent.name != "BACKGROUND_Google"

    def _is_ann(self, data: Tuple[str, Any]) -> bool:
        """Keep only files matching ``annotation_<id>.mat``."""
        path = pathlib.Path(data[0])
        return bool(self._ANNS_NAME_PATTERN.match(path.name))

    def _images_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]:
        """(category, numeric id) key for an image file."""
        path = pathlib.Path(data[0])
        category = path.parent.name
        # Renamed from `id` to avoid shadowing the builtin.
        image_id = self._IMAGES_NAME_PATTERN.match(path.name).group("id")  # type: ignore[union-attr]
        return category, image_id

    def _anns_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]:
        """(category, numeric id) key for an annotation file, normalizing folder names."""
        path = pathlib.Path(data[0])

        category = path.parent.name
        if category in self._ANNS_CATEGORY_MAP:
            category = self._ANNS_CATEGORY_MAP[category]

        # Renamed from `id` to avoid shadowing the builtin.
        ann_id = self._ANNS_NAME_PATTERN.match(path.name).group("id")  # type: ignore[union-attr]

        return category, ann_id

    def _prepare_sample(
        self, data: Tuple[Tuple[str, str], Tuple[Tuple[str, BinaryIO], Tuple[str, BinaryIO]]]
    ) -> Dict[str, Any]:
        """Decode one (image, annotation) pair into the output sample dict."""
        key, (image_data, ann_data) = data
        category, _ = key
        image_path, image_buffer = image_data
        ann_path, ann_buffer = ann_data

        image = EncodedImage.from_file(image_buffer)
        ann = read_mat(ann_buffer)

        return dict(
            label=Label.from_category(category, categories=self._categories),
            image_path=image_path,
            image=image,
            ann_path=ann_path,
            # box_coord is reordered via [[2, 0, 3, 1]] into xyxy layout.
            bounding_boxes=BoundingBoxes(
                ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]],
                format="xyxy",
                spatial_size=image.spatial_size,
            ),
            contour=torch.as_tensor(ann["obj_contour"].T),
        )

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        images_dp, anns_dp = resource_dps

        images_dp = Filter(images_dp, self._is_not_background_image)
        images_dp = hint_shuffling(images_dp)
        images_dp = hint_sharding(images_dp)

        anns_dp = Filter(anns_dp, self._is_ann)

        # Join images with annotations on the (category, id) key.
        dp = IterKeyZipper(
            images_dp,
            anns_dp,
            key_fn=self._images_key_fn,
            ref_key_fn=self._anns_key_fn,
            buffer_size=INFINITE_BUFFER_SIZE,
            keep_key=True,
        )
        return Mapper(dp, self._prepare_sample)

    def __len__(self) -> int:
        # Fixed dataset size after removing BACKGROUND_Google.
        return 8677

    def _generate_categories(self) -> List[str]:
        """Derive the sorted category list from the image folder names."""
        resources = self._resources()

        dp = resources[0].load(self._root)
        dp = Filter(dp, self._is_not_background_image)

        return sorted({pathlib.Path(path).parent.name for path, _ in dp})
@register_info("caltech256")
def _caltech256_info() -> Dict[str, Any]:
    """Static info for the ``caltech256`` dataset: its category names, read from disk."""
    return {"categories": read_categories_file("caltech256")}
@register_dataset("caltech256")
class Caltech256(Dataset):
    """Caltech 256 dataset.

    - **homepage**: https://data.caltech.edu/records/20087
    """

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        skip_integrity_check: bool = False,
    ) -> None:
        self._categories = _caltech256_info()["categories"]

        super().__init__(root, skip_integrity_check=skip_integrity_check)

    def _resources(self) -> List[OnlineResource]:
        """Single Google-Drive archive containing all object categories."""
        archive = GDriveResource(
            "1r6o0pSROcV1_VwT4oSjA2FBUSCWGuxLK",
            file_name="256_ObjectCategories.tar",
            sha256="08ff01b03c65566014ae88eb0490dbe4419fc7ac4de726ee1163e39fd809543e",
        )
        return [archive]

    def _is_not_rogue_file(self, data: Tuple[str, Any]) -> bool:
        """Drop the stray ``RENAME2`` file shipped inside the archive."""
        return pathlib.Path(data[0]).name != "RENAME2"

    def _prepare_sample(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
        """Turn one (path, file) pair into the output sample dict."""
        path, buffer = data

        # Folder names look like "<number>.<category>"; the number is 1-based.
        folder_name = pathlib.Path(path).parent.name
        label_idx = int(folder_name.split(".", 1)[0]) - 1

        return dict(
            path=path,
            image=EncodedImage.from_file(buffer),
            label=Label(label_idx, categories=self._categories),
        )

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        dp = resource_dps[0]
        dp = Filter(dp, self._is_not_rogue_file)
        dp = hint_shuffling(dp)
        dp = hint_sharding(dp)
        return Mapper(dp, self._prepare_sample)

    def __len__(self) -> int:
        # Fixed dataset size.
        return 30607

    def _generate_categories(self) -> List[str]:
        """Derive the category list from the "<number>.<category>" folder names."""
        archive_dp = self._resources()[0].load(self._root)

        folder_names = sorted({pathlib.Path(file_path).parent.name for file_path, _ in archive_dp})
        return [folder_name.split(".")[1] for folder_name in folder_names]
torchvision/prototype/datasets/_builtin/caltech101.categories
deleted
100644 → 0
View file @
f44f20cf
Faces
Faces_easy
Leopards
Motorbikes
accordion
airplanes
anchor
ant
barrel
bass
beaver
binocular
bonsai
brain
brontosaurus
buddha
butterfly
camera
cannon
car_side
ceiling_fan
cellphone
chair
chandelier
cougar_body
cougar_face
crab
crayfish
crocodile
crocodile_head
cup
dalmatian
dollar_bill
dolphin
dragonfly
electric_guitar
elephant
emu
euphonium
ewer
ferry
flamingo
flamingo_head
garfield
gerenuk
gramophone
grand_piano
hawksbill
headphone
hedgehog
helicopter
ibis
inline_skate
joshua_tree
kangaroo
ketch
lamp
laptop
llama
lobster
lotus
mandolin
mayfly
menorah
metronome
minaret
nautilus
octopus
okapi
pagoda
panda
pigeon
pizza
platypus
pyramid
revolver
rhino
rooster
saxophone
schooner
scissors
scorpion
sea_horse
snoopy
soccer_ball
stapler
starfish
stegosaurus
stop_sign
strawberry
sunflower
tick
trilobite
umbrella
watch
water_lilly
wheelchair
wild_cat
windsor_chair
wrench
yin_yang
torchvision/prototype/datasets/_builtin/caltech256.categories
deleted
100644 → 0
View file @
f44f20cf
ak47
american-flag
backpack
baseball-bat
baseball-glove
basketball-hoop
bat
bathtub
bear
beer-mug
billiards
binoculars
birdbath
blimp
bonsai-101
boom-box
bowling-ball
bowling-pin
boxing-glove
brain-101
breadmaker
buddha-101
bulldozer
butterfly
cactus
cake
calculator
camel
cannon
canoe
car-tire
cartman
cd
centipede
cereal-box
chandelier-101
chess-board
chimp
chopsticks
cockroach
coffee-mug
coffin
coin
comet
computer-keyboard
computer-monitor
computer-mouse
conch
cormorant
covered-wagon
cowboy-hat
crab-101
desk-globe
diamond-ring
dice
dog
dolphin-101
doorknob
drinking-straw
duck
dumb-bell
eiffel-tower
electric-guitar-101
elephant-101
elk
ewer-101
eyeglasses
fern
fighter-jet
fire-extinguisher
fire-hydrant
fire-truck
fireworks
flashlight
floppy-disk
football-helmet
french-horn
fried-egg
frisbee
frog
frying-pan
galaxy
gas-pump
giraffe
goat
golden-gate-bridge
goldfish
golf-ball
goose
gorilla
grand-piano-101
grapes
grasshopper
guitar-pick
hamburger
hammock
harmonica
harp
harpsichord
hawksbill-101
head-phones
helicopter-101
hibiscus
homer-simpson
horse
horseshoe-crab
hot-air-balloon
hot-dog
hot-tub
hourglass
house-fly
human-skeleton
hummingbird
ibis-101
ice-cream-cone
iguana
ipod
iris
jesus-christ
joy-stick
kangaroo-101
kayak
ketch-101
killer-whale
knife
ladder
laptop-101
lathe
leopards-101
license-plate
lightbulb
light-house
lightning
llama-101
mailbox
mandolin
mars
mattress
megaphone
menorah-101
microscope
microwave
minaret
minotaur
motorbikes-101
mountain-bike
mushroom
mussels
necktie
octopus
ostrich
owl
palm-pilot
palm-tree
paperclip
paper-shredder
pci-card
penguin
people
pez-dispenser
photocopier
picnic-table
playing-card
porcupine
pram
praying-mantis
pyramid
raccoon
radio-telescope
rainbow
refrigerator
revolver-101
rifle
rotary-phone
roulette-wheel
saddle
saturn
school-bus
scorpion-101
screwdriver
segway
self-propelled-lawn-mower
sextant
sheet-music
skateboard
skunk
skyscraper
smokestack
snail
snake
sneaker
snowmobile
soccer-ball
socks
soda-can
spaghetti
speed-boat
spider
spoon
stained-glass
starfish-101
steering-wheel
stirrups
sunflower-101
superman
sushi
swan
swiss-army-knife
sword
syringe
tambourine
teapot
teddy-bear
teepee
telephone-box
tennis-ball
tennis-court
tennis-racket
theodolite
toaster
tomato
tombstone
top-hat
touring-bike
tower-pisa
traffic-light
treadmill
triceratops
tricycle
trilobite-101
tripod
t-shirt
tuning-fork
tweezer
umbrella-101
unicorn
vcr
video-projector
washing-machine
watch-101
waterfall
watermelon
welding-mask
wheelbarrow
windmill
wine-bottle
xylophone
yarmulke
yo-yo
zebra
airplanes-101
car-side-101
faces-easy-101
greyhound
tennis-shoes
toad
clutter
torchvision/prototype/datasets/_builtin/celeba.py
deleted
100644 → 0
View file @
f44f20cf
import
csv
import
pathlib
from
typing
import
Any
,
BinaryIO
,
Dict
,
Iterator
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
import
torch
from
torchdata.datapipes.iter
import
Filter
,
IterDataPipe
,
IterKeyZipper
,
Mapper
,
Zipper
from
torchvision.prototype.datasets.utils
import
Dataset
,
EncodedImage
,
GDriveResource
,
OnlineResource
from
torchvision.prototype.datasets.utils._internal
import
(
getitem
,
hint_sharding
,
hint_shuffling
,
INFINITE_BUFFER_SIZE
,
path_accessor
,
)
from
torchvision.prototype.tv_tensors
import
Label
from
torchvision.tv_tensors
import
BoundingBoxes
from
.._api
import
register_dataset
,
register_info
# CSV dialect for the CelebA annotation files: columns are separated by a
# single space, and runs of padding spaces after a delimiter are skipped
# (the files align columns with extra whitespace).
csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)
class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]):
    """Parse a CelebA annotation file into ``(image_id, row_dict)`` pairs.

    Each item from ``datapipe`` is a ``(path, binary file)`` tuple. If
    ``fieldnames`` is not given, the header is read from the file itself
    (CelebA files start with a sample-count line followed by a header line).
    """

    def __init__(
        self,
        datapipe: IterDataPipe[Tuple[Any, BinaryIO]],
        *,
        fieldnames: Optional[Sequence[str]] = None,
    ) -> None:
        # Upstream datapipe yielding (path, open binary file) pairs.
        self.datapipe = datapipe
        # Explicit column names, or None to sniff them from the file header.
        self.fieldnames = fieldnames

    def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]:
        for _, file in self.datapipe:
            try:
                # Decode lazily, line by line.
                lines = (line.decode() for line in file)

                if self.fieldnames:
                    fieldnames = self.fieldnames
                else:
                    # The first row is skipped, because it only contains the number of samples
                    next(lines)

                    # Empty field names are filtered out, because some files have an extra white space after the header
                    # line, which is recognized as extra column
                    fieldnames = [name for name in next(csv.reader([next(lines)], dialect="celeba")) if name]
                    # Some files do not include a label for the image ID column
                    if fieldnames[0] != "image_id":
                        fieldnames.insert(0, "image_id")

                for line in csv.DictReader(lines, fieldnames=fieldnames, dialect="celeba"):
                    # Pop the key column so the remaining dict holds only annotations.
                    yield line.pop("image_id"), line
            finally:
                # Always release the file handle, even if iteration stops early.
                file.close()
NAME = "celeba"


@register_info(NAME)
def _info() -> Dict[str, Any]:
    """CelebA has no static info to register; return an empty dict."""
    return {}
@register_dataset(NAME)
class CelebA(Dataset):
    """CelebA dataset.

    - **homepage**: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
    """

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        *,
        split: str = "train",
        skip_integrity_check: bool = False,
    ) -> None:
        self._split = self._verify_str_arg(split, "split", ("train", "val", "test"))

        super().__init__(root, skip_integrity_check=skip_integrity_check)

    def _resources(self) -> List[OnlineResource]:
        """Six Google-Drive files: the split list, the images, and four annotation files."""
        splits = GDriveResource(
            "0B7EVK8r0v71pY0NSMzRuSXJEVkk",
            sha256="fc955bcb3ef8fbdf7d5640d9a8693a8431b5f2ee291a5c1449a1549e7e073fe7",
            file_name="list_eval_partition.txt",
        )
        images = GDriveResource(
            "0B7EVK8r0v71pZjFTYXZWM3FlRnM",
            sha256="46fb89443c578308acf364d7d379fe1b9efb793042c0af734b6112e4fd3a8c74",
            file_name="img_align_celeba.zip",
        )
        identities = GDriveResource(
            "1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS",
            sha256="c6143857c3e2630ac2da9f782e9c1232e5e59be993a9d44e8a7916c78a6158c0",
            file_name="identity_CelebA.txt",
        )
        attributes = GDriveResource(
            "0B7EVK8r0v71pblRyaVFSWGxPY0U",
            sha256="f0e5da289d5ccf75ffe8811132694922b60f2af59256ed362afa03fefba324d0",
            file_name="list_attr_celeba.txt",
        )
        bounding_boxes = GDriveResource(
            "0B7EVK8r0v71pbThiMVRxWXZ4dU0",
            sha256="7487a82e57c4bb956c5445ae2df4a91ffa717e903c5fa22874ede0820c8ec41b",
            file_name="list_bbox_celeba.txt",
        )
        landmarks = GDriveResource(
            "0B7EVK8r0v71pd0FJY3Blby1HUTQ",
            sha256="6c02a87569907f6db2ba99019085697596730e8129f67a3d61659f198c48d43b",
            file_name="list_landmarks_align_celeba.txt",
        )
        return [splits, images, identities, attributes, bounding_boxes, landmarks]

    def _filter_split(self, data: Tuple[str, Dict[str, str]]) -> bool:
        """Keep rows of the eval-partition file matching the configured split.

        The file encodes the split as "0" (train), "1" (val), or "2" (test).
        """
        split_id = {
            "train": "0",
            "val": "1",
            "test": "2",
        }[self._split]
        return data[1]["split_id"] == split_id

    def _prepare_sample(
        self,
        data: Tuple[
            Tuple[str, Tuple[Tuple[str, List[str]], Tuple[str, BinaryIO]]],
            Tuple[
                Tuple[str, Dict[str, str]],
                Tuple[str, Dict[str, str]],
                Tuple[str, Dict[str, str]],
                Tuple[str, Dict[str, str]],
            ],
        ],
    ) -> Dict[str, Any]:
        """Assemble one sample from the zipped (split+image, annotations) item.

        The nested tuples mirror the two IterKeyZipper joins in ``_datapipe``:
        the first element carries the split row and the image file, the second
        the four annotation rows in the order (identity, attributes,
        bounding boxes, landmarks).
        """
        split_and_image_data, ann_data = data
        _, (_, image_data) = split_and_image_data
        path, buffer = image_data

        image = EncodedImage.from_file(buffer)
        (_, identity), (_, attributes), (_, bounding_boxes), (_, landmarks) = ann_data

        return dict(
            path=path,
            image=image,
            identity=Label(int(identity["identity"])),
            # Attribute values are "1" / "-1" in the file; map to bool.
            attributes={attr: value == "1" for attr, value in attributes.items()},
            bounding_boxes=BoundingBoxes(
                [int(bounding_boxes[key]) for key in ("x_1", "y_1", "width", "height")],
                format="xywh",
                spatial_size=image.spatial_size,
            ),
            # Landmark columns come in "<name>_x" / "<name>_y" pairs; strip the
            # suffix to recover the landmark names and pair the coordinates.
            landmarks={
                landmark: torch.tensor((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"])))
                for landmark in {key[:-2] for key in landmarks.keys()}
            },
        )

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        splits_dp, images_dp, identities_dp, attributes_dp, bounding_boxes_dp, landmarks_dp = resource_dps

        splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id"))
        splits_dp = Filter(splits_dp, self._filter_split)
        splits_dp = hint_shuffling(splits_dp)
        splits_dp = hint_sharding(splits_dp)

        # Zip the four annotation files row-wise; all are keyed by image_id.
        anns_dp = Zipper(
            *[
                CelebACSVParser(dp, fieldnames=fieldnames)
                for dp, fieldnames in (
                    (identities_dp, ("image_id", "identity")),
                    (attributes_dp, None),
                    (bounding_boxes_dp, None),
                    (landmarks_dp, None),
                )
            ]
        )

        # Join split rows with image files by file name, keeping the key ...
        dp = IterKeyZipper(
            splits_dp,
            images_dp,
            key_fn=getitem(0),
            ref_key_fn=path_accessor("name"),
            buffer_size=INFINITE_BUFFER_SIZE,
            keep_key=True,
        )
        # ... then join with the zipped annotations by the same image_id key.
        dp = IterKeyZipper(
            dp,
            anns_dp,
            key_fn=getitem(0),
            ref_key_fn=getitem(0, 0),
            buffer_size=INFINITE_BUFFER_SIZE,
        )
        return Mapper(dp, self._prepare_sample)

    def __len__(self) -> int:
        # Fixed per-split sizes.
        return {
            "train": 162_770,
            "val": 19_867,
            "test": 19_962,
        }[self._split]
torchvision/prototype/datasets/_builtin/cifar.py
deleted
100644 → 0
View file @
f44f20cf
import
abc
import
io
import
pathlib
import
pickle
from
typing
import
Any
,
BinaryIO
,
cast
,
Dict
,
Iterator
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
from
torchdata.datapipes.iter
import
Filter
,
IterDataPipe
,
Mapper
from
torchvision.prototype.datasets.utils
import
Dataset
,
HttpResource
,
OnlineResource
from
torchvision.prototype.datasets.utils._internal
import
(
hint_sharding
,
hint_shuffling
,
path_comparator
,
read_categories_file
,
)
from
torchvision.prototype.tv_tensors
import
Label
from
torchvision.tv_tensors
import
Image
from
.._api
import
register_dataset
,
register_info
class CifarFileReader(IterDataPipe[Tuple[np.ndarray, int]]):
    """Expand unpickled CIFAR batch dicts into individual (image, label) pairs.

    Each upstream item is a dict holding a flat ``"data"`` array plus a list of
    labels under ``labels_key``; the array is reshaped to NCHW (3x32x32) and
    zipped with the labels.
    """

    def __init__(self, datapipe: IterDataPipe[Dict[str, Any]], *, labels_key: str) -> None:
        self.datapipe = datapipe
        # Key of the label list inside each batch dict ("labels" or "fine_labels").
        self.labels_key = labels_key

    def __iter__(self) -> Iterator[Tuple[np.ndarray, int]]:
        for batch in self.datapipe:
            images = batch["data"].reshape((-1, 3, 32, 32))
            labels = batch[self.labels_key]
            for image, label in zip(images, labels):
                yield image, label
class _CifarBase(Dataset):
    """Shared implementation for CIFAR-10 and CIFAR-100.

    Subclasses provide the archive name/checksum, the label and category keys
    inside the pickled files, and ``_is_data_file`` to select the batch files
    of the configured split.
    """

    # Archive file name on the download server.
    _FILE_NAME: str
    # Expected sha256 of the archive.
    _SHA256: str
    # Dict key of the label list in each pickled batch.
    _LABELS_KEY: str
    # Name of the pickled metadata file inside the archive.
    _META_FILE_NAME: str
    # Dict key of the category names in the metadata file.
    _CATEGORIES_KEY: str
    _categories: List[str]

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        *,
        split: str = "train",
        skip_integrity_check: bool = False,
    ) -> None:
        self._split = self._verify_str_arg(split, "split", ("train", "test"))
        super().__init__(root, skip_integrity_check=skip_integrity_check)

    @abc.abstractmethod
    def _is_data_file(self, data: Tuple[str, BinaryIO]) -> bool:
        """Return whether the archive member belongs to the configured split.

        Used as a ``Filter`` predicate; both subclasses return a plain bool
        (annotation tightened from ``Optional[int]`` to match them).
        """
        pass

    def _resources(self) -> List[OnlineResource]:
        return [
            HttpResource(
                f"https://www.cs.toronto.edu/~kriz/{self._FILE_NAME}",
                sha256=self._SHA256,
            )
        ]

    def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
        """Unpickle one archive member and close its file handle.

        latin1 is the encoding conventionally used to load these
        Python-2-era pickles (see the ``pickle.load`` docs).
        """
        _, file = data
        content = cast(Dict[str, Any], pickle.load(file, encoding="latin1"))
        file.close()
        return content

    def _prepare_sample(self, data: Tuple[np.ndarray, int]) -> Dict[str, Any]:
        """Wrap one (image array, label index) pair into the output sample dict."""
        image_array, category_idx = data
        return dict(
            image=Image(image_array),
            label=Label(category_idx, categories=self._categories),
        )

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        dp = resource_dps[0]
        dp = Filter(dp, self._is_data_file)
        dp = Mapper(dp, self._unpickle)
        dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
        dp = hint_shuffling(dp)
        dp = hint_sharding(dp)
        return Mapper(dp, self._prepare_sample)

    def __len__(self) -> int:
        # Fixed per-split sizes shared by CIFAR-10 and CIFAR-100.
        return 50_000 if self._split == "train" else 10_000

    def _generate_categories(self) -> List[str]:
        """Read the category names from the pickled metadata file in the archive."""
        resources = self._resources()

        dp = resources[0].load(self._root)
        dp = Filter(dp, path_comparator("name", self._META_FILE_NAME))
        dp = Mapper(dp, self._unpickle)

        return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY])
@register_info("cifar10")
def _cifar10_info() -> Dict[str, Any]:
    """Static info for the ``cifar10`` dataset: its category names, read from disk."""
    return {"categories": read_categories_file("cifar10")}
@register_dataset("cifar10")
class Cifar10(_CifarBase):
    """CIFAR-10 dataset.

    - **homepage**: https://www.cs.toronto.edu/~kriz/cifar.html
    """

    _FILE_NAME = "cifar-10-python.tar.gz"
    _SHA256 = "6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce"
    _LABELS_KEY = "labels"
    _META_FILE_NAME = "batches.meta"
    _CATEGORIES_KEY = "label_names"
    _categories = _cifar10_info()["categories"]

    def _is_data_file(self, data: Tuple[str, Any]) -> bool:
        # Training batches are "data_batch_*", the evaluation batch is "test_batch".
        wanted_prefix = "data" if self._split == "train" else "test"
        return pathlib.Path(data[0]).name.startswith(wanted_prefix)
@register_info("cifar100")
def _cifar100_info() -> Dict[str, Any]:
    """Static info for the ``cifar100`` dataset: its category names, read from disk."""
    return {"categories": read_categories_file("cifar100")}
@register_dataset("cifar100")
class Cifar100(_CifarBase):
    """CIFAR-100 dataset.

    - **homepage**: https://www.cs.toronto.edu/~kriz/cifar.html
    """

    _FILE_NAME = "cifar-100-python.tar.gz"
    _SHA256 = "85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7"
    _LABELS_KEY = "fine_labels"
    _META_FILE_NAME = "meta"
    _CATEGORIES_KEY = "fine_label_names"
    _categories = _cifar100_info()["categories"]

    def _is_data_file(self, data: Tuple[str, Any]) -> bool:
        # The archive contains exactly one file per split, named after it.
        file_name = pathlib.Path(data[0]).name
        return file_name == self._split
torchvision/prototype/datasets/_builtin/cifar10.categories
deleted
100644 → 0
View file @
f44f20cf
airplane
automobile
bird
cat
deer
dog
frog
horse
ship
truck
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment