Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
bdf16222
Unverified
Commit
bdf16222
authored
Jul 31, 2023
by
Philip Meier
Committed by
GitHub
Jul 31, 2023
Browse files
add support for instance checks on dataset wrappers (#7239)
parent
9b4ec8df
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
17 additions
and
11 deletions
+17
-11
references/detection/coco_utils.py
references/detection/coco_utils.py
+1
-3
references/detection/group_by_aspect_ratio.py
references/detection/group_by_aspect_ratio.py
+1
-3
test/datasets_utils.py
test/datasets_utils.py
+4
-2
torchvision/datapoints/_dataset_wrapper.py
torchvision/datapoints/_dataset_wrapper.py
+11
-3
No files found.
references/detection/coco_utils.py
View file @
bdf16222
...
...
@@ -178,9 +178,7 @@ def get_coco_api_from_dataset(dataset):
break
if
isinstance
(
dataset
,
torch
.
utils
.
data
.
Subset
):
dataset
=
dataset
.
dataset
if
isinstance
(
dataset
,
torchvision
.
datasets
.
CocoDetection
)
or
isinstance
(
getattr
(
dataset
,
"_dataset"
,
None
),
torchvision
.
datasets
.
CocoDetection
):
if
isinstance
(
dataset
,
torchvision
.
datasets
.
CocoDetection
):
return
dataset
.
coco
return
convert_to_coco_api
(
dataset
)
...
...
references/detection/group_by_aspect_ratio.py
View file @
bdf16222
...
...
@@ -164,9 +164,7 @@ def compute_aspect_ratios(dataset, indices=None):
if
hasattr
(
dataset
,
"get_height_and_width"
):
return
_compute_aspect_ratios_custom_dataset
(
dataset
,
indices
)
if
isinstance
(
dataset
,
torchvision
.
datasets
.
CocoDetection
)
or
isinstance
(
getattr
(
dataset
,
"_dataset"
,
None
),
torchvision
.
datasets
.
CocoDetection
):
if
isinstance
(
dataset
,
torchvision
.
datasets
.
CocoDetection
):
return
_compute_aspect_ratios_coco_dataset
(
dataset
,
indices
)
if
isinstance
(
dataset
,
torchvision
.
datasets
.
VOCDetection
):
...
...
test/datasets_utils.py
View file @
bdf16222
...
...
@@ -571,7 +571,7 @@ class DatasetTestCase(unittest.TestCase):
from
torchvision.datasets
import
wrap_dataset_for_transforms_v2
try
:
with
self
.
create_dataset
(
config
)
as
(
dataset
,
_
):
with
self
.
create_dataset
(
config
)
as
(
dataset
,
info
):
for
target_keys
in
[
None
,
"all"
]:
if
target_keys
is
not
None
and
self
.
DATASET_CLASS
not
in
{
torchvision
.
datasets
.
CocoDetection
,
...
...
@@ -584,8 +584,10 @@ class DatasetTestCase(unittest.TestCase):
continue
wrapped_dataset
=
wrap_dataset_for_transforms_v2
(
dataset
,
target_keys
=
target_keys
)
wrapped_sample
=
wrapped_dataset
[
0
]
assert
isinstance
(
wrapped_dataset
,
self
.
DATASET_CLASS
)
assert
len
(
wrapped_dataset
)
==
info
[
"num_examples"
]
wrapped_sample
=
wrapped_dataset
[
0
]
assert
tree_any
(
lambda
item
:
isinstance
(
item
,
(
Datapoint
,
PIL
.
Image
.
Image
)),
wrapped_sample
)
except
TypeError
as
error
:
msg
=
f
"No wrapper exists for dataset class
{
type
(
dataset
).
__name__
}
"
...
...
torchvision/datapoints/_dataset_wrapper.py
View file @
bdf16222
...
...
@@ -8,7 +8,6 @@ import contextlib
from
collections
import
defaultdict
import
torch
from
torch.utils.data
import
Dataset
from
torchvision
import
datapoints
,
datasets
from
torchvision.transforms.v2
import
functional
as
F
...
...
@@ -98,7 +97,16 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
f
"but got
{
target_keys
}
"
)
return
VisionDatasetDatapointWrapper
(
dataset
,
target_keys
)
# Imagine we have isinstance(dataset, datasets.ImageNet). This will create a new class with the name
# "WrappedImageNet" at runtime that doubly inherits from VisionDatasetDatapointWrapper (see below) as well as the
# original ImageNet class. This allows the user to do regular isinstance(wrapped_dataset, datasets.ImageNet) checks,
# while we can still inject everything that we need.
wrapped_dataset_cls
=
type
(
f
"Wrapped
{
type
(
dataset
).
__name__
}
"
,
(
VisionDatasetDatapointWrapper
,
type
(
dataset
)),
{})
# Since VisionDatasetDatapointWrapper comes before ImageNet in the MRO, calling the class hits
# VisionDatasetDatapointWrapper.__init__ first. Since we are never doing super().__init__(...), the constructor of
# ImageNet is never hit. That is by design, since we don't want to create the dataset instance again, but rather
# have the existing instance as attribute on the new object.
return
wrapped_dataset_cls
(
dataset
,
target_keys
)
class
WrapperFactories
(
dict
):
...
...
@@ -117,7 +125,7 @@ class WrapperFactories(dict):
WRAPPER_FACTORIES
=
WrapperFactories
()
class
VisionDatasetDatapointWrapper
(
Dataset
)
:
class
VisionDatasetDatapointWrapper
:
def
__init__
(
self
,
dataset
,
target_keys
):
dataset_cls
=
type
(
dataset
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment