Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
408c9bea
Unverified
Commit
408c9bea
authored
Nov 18, 2021
by
Philip Meier
Committed by
GitHub
Nov 18, 2021
Browse files
make prototype datasets traversable (#4950)
Co-authored-by:
Francisco Massa
<
fvsmassa@gmail.com
>
parent
59baae99
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
25 additions
and
17 deletions
+25
-17
test/test_prototype_builtin_datasets.py
test/test_prototype_builtin_datasets.py
+5
-0
torchvision/prototype/datasets/utils/_internal.py
torchvision/prototype/datasets/utils/_internal.py
+20
-17
No files found.
test/test_prototype_builtin_datasets.py
View file @
408c9bea
...
@@ -2,6 +2,7 @@ import io
...
@@ -2,6 +2,7 @@ import io
import
builtin_dataset_mocks
import
builtin_dataset_mocks
import
pytest
import
pytest
from
torch.utils.data.graph
import
traverse
from
torchdata.datapipes.iter
import
IterDataPipe
from
torchdata.datapipes.iter
import
IterDataPipe
from
torchvision.prototype
import
datasets
,
features
from
torchvision.prototype
import
datasets
,
features
from
torchvision.prototype.datasets._api
import
DEFAULT_DECODER
from
torchvision.prototype.datasets._api
import
DEFAULT_DECODER
...
@@ -83,6 +84,10 @@ class TestCommon:
...
@@ -83,6 +84,10 @@ class TestCommon:
if
not
any
(
isinstance
(
value
,
features
.
Feature
)
for
value
in
sample
.
values
()):
if
not
any
(
isinstance
(
value
,
features
.
Feature
)
for
value
in
sample
.
values
()):
raise
AssertionError
(
"The sample contained no feature."
)
raise
AssertionError
(
"The sample contained no feature."
)
@
dataset_parametrization
()
def
test_traversable
(
self
,
dataset
,
mock_info
):
traverse
(
dataset
)
class
TestQMNIST
:
class
TestQMNIST
:
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
...
torchvision/prototype/datasets/utils/_internal.py
View file @
408c9bea
import
enum
import
enum
import
functools
import
gzip
import
gzip
import
io
import
io
import
lzma
import
lzma
...
@@ -101,35 +102,37 @@ class Enumerator(IterDataPipe[Tuple[int, D]]):
...
@@ -101,35 +102,37 @@ class Enumerator(IterDataPipe[Tuple[int, D]]):
yield
from
enumerate
(
self
.
datapipe
,
self
.
start
)
yield
from
enumerate
(
self
.
datapipe
,
self
.
start
)
def
_getitem_closure
(
obj
:
Any
,
*
,
items
:
Tuple
[
Any
,
...])
->
Any
:
for
item
in
items
:
obj
=
obj
[
item
]
return
obj
def
getitem
(
*
items
:
Any
)
->
Callable
[[
Any
],
Any
]:
def
getitem
(
*
items
:
Any
)
->
Callable
[[
Any
],
Any
]:
def
wrapper
(
obj
:
Any
)
->
Any
:
return
functools
.
partial
(
_getitem_closure
,
items
=
items
)
for
item
in
items
:
obj
=
obj
[
item
]
return
obj
def
_path_attribute_accessor
(
path
:
pathlib
.
Path
,
*
,
name
:
str
)
->
D
:
return
cast
(
D
,
getattr
(
path
,
name
))
return
wrapper
def
_path_accessor_closure
(
data
:
Tuple
[
str
,
Any
],
*
,
getter
:
Callable
[[
pathlib
.
Path
],
D
])
->
D
:
return
getter
(
pathlib
.
Path
(
data
[
0
]))
def
path_accessor
(
getter
:
Union
[
str
,
Callable
[[
pathlib
.
Path
],
D
]])
->
Callable
[[
Tuple
[
str
,
Any
]],
D
]:
def
path_accessor
(
getter
:
Union
[
str
,
Callable
[[
pathlib
.
Path
],
D
]])
->
Callable
[[
Tuple
[
str
,
Any
]],
D
]:
if
isinstance
(
getter
,
str
):
if
isinstance
(
getter
,
str
):
name
=
getter
getter
=
functools
.
partial
(
_path_attribute_accessor
,
name
=
getter
)
def
getter
(
path
:
pathlib
.
Path
)
->
D
:
return
functools
.
partial
(
_path_accessor_closure
,
getter
=
getter
)
return
cast
(
D
,
getattr
(
path
,
name
))
def
wrapper
(
data
:
Tuple
[
str
,
Any
])
->
D
:
return
getter
(
pathlib
.
Path
(
data
[
0
]))
# type: ignore[operator]
return
wrapper
def
_path_comparator_closure
(
data
:
Tuple
[
str
,
Any
],
*
,
accessor
:
Callable
[[
Tuple
[
str
,
Any
]],
D
],
value
:
D
)
->
bool
:
return
accessor
(
data
)
==
value
def
path_comparator
(
getter
:
Union
[
str
,
Callable
[[
pathlib
.
Path
],
D
]],
value
:
D
)
->
Callable
[[
Tuple
[
str
,
Any
]],
bool
]:
def
path_comparator
(
getter
:
Union
[
str
,
Callable
[[
pathlib
.
Path
],
D
]],
value
:
D
)
->
Callable
[[
Tuple
[
str
,
Any
]],
bool
]:
accessor
=
path_accessor
(
getter
)
return
functools
.
partial
(
_path_comparator_closure
,
accessor
=
path_accessor
(
getter
),
value
=
value
)
def
wrapper
(
data
:
Tuple
[
str
,
Any
])
->
bool
:
return
accessor
(
data
)
==
value
return
wrapper
class
CompressionType
(
enum
.
Enum
):
class
CompressionType
(
enum
.
Enum
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment