Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
d744da93
Unverified
Commit
d744da93
authored
Feb 14, 2023
by
Philip Meier
Committed by
GitHub
Feb 14, 2023
Browse files
allow subclasses in dataset wrappers (#7236)
Co-authored-by:
Nicolas Hug
<
contact@nicolas-hug.com
>
parent
b570f2c1
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
76 additions
and
15 deletions
+76
-15
test/datasets_utils.py
test/datasets_utils.py
+1
-1
test/test_prototype_datapoints.py
test/test_prototype_datapoints.py
+44
-0
torchvision/prototype/datapoints/_dataset_wrapper.py
torchvision/prototype/datapoints/_dataset_wrapper.py
+31
-14
No files found.
test/datasets_utils.py
View file @
d744da93
...
...
@@ -596,7 +596,7 @@ class DatasetTestCase(unittest.TestCase):
wrapped_sample
=
wrapped_dataset
[
0
]
assert
tree_any
(
lambda
item
:
isinstance
(
item
,
(
Datapoint
,
PIL
.
Image
.
Image
)),
wrapped_sample
)
except
TypeError
as
error
:
if
str
(
error
).
startswith
(
f
"No wrapper exist for dataset class
{
type
(
dataset
).
__name__
}
"
):
if
str
(
error
).
startswith
(
f
"No wrapper exist
s
for dataset class
{
type
(
dataset
).
__name__
}
"
):
return
raise
error
except
RuntimeError
as
error
:
...
...
test/test_prototype_datapoints.py
View file @
d744da93
import
re
import
pytest
import
torch
from
PIL
import
Image
from
torchvision
import
datasets
from
torchvision.prototype
import
datapoints
...
...
@@ -159,3 +163,43 @@ def test_bbox_instance(data, format):
if
isinstance
(
format
,
str
):
format
=
datapoints
.
BoundingBoxFormat
.
from_str
(
format
.
upper
())
assert
bboxes
.
format
==
format
class
TestDatasetWrapper
:
def
test_unknown_type
(
self
):
unknown_object
=
object
()
with
pytest
.
raises
(
TypeError
,
match
=
re
.
escape
(
"is meant for subclasses of `torchvision.datasets.VisionDataset`"
)
):
datapoints
.
wrap_dataset_for_transforms_v2
(
unknown_object
)
def
test_unknown_dataset
(
self
):
class
MyVisionDataset
(
datasets
.
VisionDataset
):
pass
dataset
=
MyVisionDataset
(
"root"
)
with
pytest
.
raises
(
TypeError
,
match
=
"No wrapper exist"
):
datapoints
.
wrap_dataset_for_transforms_v2
(
dataset
)
def
test_missing_wrapper
(
self
):
dataset
=
datasets
.
FakeData
()
with
pytest
.
raises
(
TypeError
,
match
=
"please open an issue"
):
datapoints
.
wrap_dataset_for_transforms_v2
(
dataset
)
def
test_subclass
(
self
,
mocker
):
sentinel
=
object
()
mocker
.
patch
.
dict
(
datapoints
.
_dataset_wrapper
.
WRAPPER_FACTORIES
,
clear
=
False
,
values
=
{
datasets
.
FakeData
:
lambda
dataset
:
lambda
idx
,
sample
:
sentinel
},
)
class
MyFakeData
(
datasets
.
FakeData
):
pass
dataset
=
MyFakeData
()
wrapped_dataset
=
datapoints
.
wrap_dataset_for_transforms_v2
(
dataset
)
assert
wrapped_dataset
[
0
]
is
sentinel
torchvision/prototype/datapoints/_dataset_wrapper.py
View file @
d744da93
...
...
@@ -39,16 +39,26 @@ WRAPPER_FACTORIES = WrapperFactories()
class
VisionDatasetDatapointWrapper
(
Dataset
):
def
__init__
(
self
,
dataset
):
dataset_cls
=
type
(
dataset
)
wrapper_factory
=
WRAPPER_FACTORIES
.
get
(
dataset_cls
)
if
wrapper_factory
is
None
:
# TODO: If we have documentation on how to do that, put a link in the error message.
msg
=
f
"No wrapper exist for dataset class
{
dataset_cls
.
__name__
}
. Please wrap the output yourself."
if
dataset_cls
in
datasets
.
__dict__
.
values
():
msg
=
(
f
"
{
msg
}
If an automated wrapper for this dataset would be useful for you, "
f
"please open an issue at https://github.com/pytorch/vision/issues."
)
raise
TypeError
(
msg
)
if
not
isinstance
(
dataset
,
datasets
.
VisionDataset
):
raise
TypeError
(
f
"This wrapper is meant for subclasses of `torchvision.datasets.VisionDataset`, "
f
"but got a '
{
dataset_cls
.
__name__
}
' instead."
)
for
cls
in
dataset_cls
.
mro
():
if
cls
in
WRAPPER_FACTORIES
:
wrapper_factory
=
WRAPPER_FACTORIES
[
cls
]
break
elif
cls
is
datasets
.
VisionDataset
:
# TODO: If we have documentation on how to do that, put a link in the error message.
msg
=
f
"No wrapper exists for dataset class
{
dataset_cls
.
__name__
}
. Please wrap the output yourself."
if
dataset_cls
in
datasets
.
__dict__
.
values
():
msg
=
(
f
"
{
msg
}
If an automated wrapper for this dataset would be useful for you, "
f
"please open an issue at https://github.com/pytorch/vision/issues."
)
raise
TypeError
(
msg
)
self
.
_dataset
=
dataset
self
.
_wrapper
=
wrapper_factory
(
dataset
)
...
...
@@ -98,6 +108,13 @@ def identity(item):
return
item
def
identity_wrapper_factory
(
dataset
):
def
wrapper
(
idx
,
sample
):
return
sample
return
wrapper
def
pil_image_to_mask
(
pil_image
):
return
datapoints
.
Mask
(
pil_image
)
...
...
@@ -125,10 +142,7 @@ def wrap_target_by_type(target, *, target_types, type_wrappers):
def
classification_wrapper_factory
(
dataset
):
def
wrapper
(
idx
,
sample
):
return
sample
return
wrapper
return
identity_wrapper_factory
(
dataset
)
for
dataset_cls
in
[
...
...
@@ -237,6 +251,9 @@ def coco_dectection_wrapper_factory(dataset):
return
wrapper
WRAPPER_FACTORIES
.
register
(
datasets
.
CocoCaptions
)(
identity_wrapper_factory
)
VOC_DETECTION_CATEGORIES
=
[
"__background__"
,
"aeroplane"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment