Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
70745705
"torchvision/vscode:/vscode.git/clone" did not exist on "8263c8a1dc0cd597a46a9c72e7da792e666e9890"
Unverified
Commit
70745705
authored
Feb 13, 2023
by
Philip Meier
Committed by
GitHub
Feb 13, 2023
Browse files
call dataset wrapper with idx and sample (#7235)
parent
b030e936
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
14 deletions
+20
-14
torchvision/prototype/datapoints/_dataset_wrapper.py
torchvision/prototype/datapoints/_dataset_wrapper.py
+20
-14
No files found.
torchvision/prototype/datapoints/_dataset_wrapper.py
View file @
70745705
...
...
@@ -74,7 +74,7 @@ class VisionDatasetDatapointWrapper(Dataset):
# of this class
sample
=
self
.
_dataset
[
idx
]
sample
=
self
.
_wrapper
(
sample
)
sample
=
self
.
_wrapper
(
idx
,
sample
)
# Regardless of whether the user has supplied the transforms individually (`transform` and `target_transform`)
# or joint (`transforms`), we can access the full functionality through `transforms`
...
...
@@ -125,7 +125,10 @@ def wrap_target_by_type(target, *, target_types, type_wrappers):
def
classification_wrapper_factory
(
dataset
):
return
identity
def
wrapper
(
idx
,
sample
):
return
sample
return
wrapper
for
dataset_cls
in
[
...
...
@@ -143,7 +146,7 @@ for dataset_cls in [
def
segmentation_wrapper_factory
(
dataset
):
def
wrapper
(
sample
):
def
wrapper
(
idx
,
sample
):
image
,
mask
=
sample
return
image
,
pil_image_to_mask
(
mask
)
...
...
@@ -163,7 +166,7 @@ def video_classification_wrapper_factory(dataset):
f
"since it is not compatible with the transformations. Please use `output_format='TCHW'` instead."
)
def
wrapper
(
sample
):
def
wrapper
(
idx
,
sample
):
video
,
audio
,
label
=
sample
video
=
datapoints
.
Video
(
video
)
...
...
@@ -201,14 +204,17 @@ def coco_dectection_wrapper_factory(dataset):
)
return
torch
.
from_numpy
(
mask
.
decode
(
segmentation
))
def
wrapper
(
sample
):
def
wrapper
(
idx
,
sample
):
image_id
=
dataset
.
ids
[
idx
]
image
,
target
=
sample
if
not
target
:
return
image
,
dict
(
image_id
=
image_id
)
batched_target
=
list_of_dicts_to_dict_of_lists
(
target
)
image_ids
=
batched_target
.
pop
(
"image_id"
)
image_id
=
batched_target
[
"image_id"
]
=
image_ids
.
pop
()
assert
all
(
other_image_id
==
image_id
for
other_image_id
in
image_ids
)
batched_target
[
"image_id"
]
=
image_id
spatial_size
=
tuple
(
F
.
get_spatial_size
(
image
))
batched_target
[
"boxes"
]
=
datapoints
.
BoundingBox
(
...
...
@@ -259,7 +265,7 @@ VOC_DETECTION_CATEGORY_TO_IDX = dict(zip(VOC_DETECTION_CATEGORIES, range(len(VOC
@
WRAPPER_FACTORIES
.
register
(
datasets
.
VOCDetection
)
def
voc_detection_wrapper_factory
(
dataset
):
def
wrapper
(
sample
):
def
wrapper
(
idx
,
sample
):
image
,
target
=
sample
batched_instances
=
list_of_dicts_to_dict_of_lists
(
target
[
"annotation"
][
"object"
])
...
...
@@ -294,7 +300,7 @@ def celeba_wrapper_factory(dataset):
if
any
(
target_type
in
dataset
.
target_type
for
target_type
in
[
"attr"
,
"landmarks"
]):
raise_not_supported
(
"`CelebA` dataset with `target_type=['attr', 'landmarks', ...]`"
)
def
wrapper
(
sample
):
def
wrapper
(
idx
,
sample
):
image
,
target
=
sample
target
=
wrap_target_by_type
(
...
...
@@ -318,7 +324,7 @@ KITTI_CATEGORY_TO_IDX = dict(zip(KITTI_CATEGORIES, range(len(KITTI_CATEGORIES)))
@
WRAPPER_FACTORIES
.
register
(
datasets
.
Kitti
)
def
kitti_wrapper_factory
(
dataset
):
def
wrapper
(
sample
):
def
wrapper
(
idx
,
sample
):
image
,
target
=
sample
if
target
is
not
None
:
...
...
@@ -336,7 +342,7 @@ def kitti_wrapper_factory(dataset):
@
WRAPPER_FACTORIES
.
register
(
datasets
.
OxfordIIITPet
)
def
oxford_iiit_pet_wrapper_factor
(
dataset
):
def
wrapper
(
sample
):
def
wrapper
(
idx
,
sample
):
image
,
target
=
sample
if
target
is
not
None
:
...
...
@@ -371,7 +377,7 @@ def cityscapes_wrapper_factory(dataset):
labels
.
append
(
label
)
return
dict
(
masks
=
datapoints
.
Mask
(
torch
.
stack
(
masks
)),
labels
=
torch
.
stack
(
labels
))
def
wrapper
(
sample
):
def
wrapper
(
idx
,
sample
):
image
,
target
=
sample
target
=
wrap_target_by_type
(
...
...
@@ -390,7 +396,7 @@ def cityscapes_wrapper_factory(dataset):
@
WRAPPER_FACTORIES
.
register
(
datasets
.
WIDERFace
)
def
widerface_wrapper
(
dataset
):
def
wrapper
(
sample
):
def
wrapper
(
idx
,
sample
):
image
,
target
=
sample
if
target
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment