Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
7c9878a4
Unverified
Commit
7c9878a4
authored
Feb 10, 2023
by
Philip Meier
Committed by
GitHub
Feb 10, 2023
Browse files
remove datapoints compatibility for prototype datasets (#7154)
parent
a9d25721
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
25 additions
and
36 deletions
+25
-36
test/test_prototype_datasets_builtin.py
test/test_prototype_datasets_builtin.py
+10
-6
torchvision/prototype/datapoints/_datapoint.py
torchvision/prototype/datapoints/_datapoint.py
+1
-18
torchvision/prototype/datasets/_builtin/caltech.py
torchvision/prototype/datasets/_builtin/caltech.py
+3
-2
torchvision/prototype/datasets/_builtin/celeba.py
torchvision/prototype/datasets/_builtin/celeba.py
+2
-2
torchvision/prototype/datasets/_builtin/coco.py
torchvision/prototype/datasets/_builtin/coco.py
+2
-3
torchvision/prototype/datasets/_builtin/cub200.py
torchvision/prototype/datasets/_builtin/cub200.py
+2
-2
torchvision/prototype/datasets/_builtin/sbd.py
torchvision/prototype/datasets/_builtin/sbd.py
+5
-3
No files found.
test/test_prototype_datasets_builtin.py
View file @
7c9878a4
...
@@ -21,6 +21,7 @@ from torchdata.datapipes.iter import ShardingFilter, Shuffler
...
@@ -21,6 +21,7 @@ from torchdata.datapipes.iter import ShardingFilter, Shuffler
from
torchdata.datapipes.utils
import
StreamWrapper
from
torchdata.datapipes.utils
import
StreamWrapper
from
torchvision._utils
import
sequence_to_str
from
torchvision._utils
import
sequence_to_str
from
torchvision.prototype
import
datapoints
,
datasets
,
transforms
from
torchvision.prototype
import
datapoints
,
datasets
,
transforms
from
torchvision.prototype.datasets.utils
import
EncodedImage
from
torchvision.prototype.datasets.utils._internal
import
INFINITE_BUFFER_SIZE
from
torchvision.prototype.datasets.utils._internal
import
INFINITE_BUFFER_SIZE
...
@@ -136,18 +137,21 @@ class TestCommon:
...
@@ -136,18 +137,21 @@ class TestCommon:
raise
AssertionError
(
make_msg_and_close
(
"The following streams were not closed after a full iteration:"
))
raise
AssertionError
(
make_msg_and_close
(
"The following streams were not closed after a full iteration:"
))
@
parametrize_dataset_mocks
(
DATASET_MOCKS
)
@
parametrize_dataset_mocks
(
DATASET_MOCKS
)
def
test_no_simple_tensors
(
self
,
dataset_mock
,
config
):
def
test_no_
unaccompanied_
simple_tensors
(
self
,
dataset_mock
,
config
):
dataset
,
_
=
dataset_mock
.
load
(
config
)
dataset
,
_
=
dataset_mock
.
load
(
config
)
sample
=
next_consume
(
iter
(
dataset
))
simple_tensors
=
{
simple_tensors
=
{
key
key
for
key
,
value
in
sample
.
items
()
if
torchvision
.
prototype
.
transforms
.
utils
.
is_simple_tensor
(
value
)
for
key
,
value
in
next_consume
(
iter
(
dataset
)).
items
()
if
torchvision
.
prototype
.
transforms
.
utils
.
is_simple_tensor
(
value
)
}
}
if
simple_tensors
:
if
simple_tensors
and
not
any
(
isinstance
(
item
,
(
datapoints
.
Image
,
datapoints
.
Video
,
EncodedImage
))
for
item
in
sample
.
values
()
):
raise
AssertionError
(
raise
AssertionError
(
f
"The values of key(s) "
f
"The values of key(s) "
f
"
{
sequence_to_str
(
sorted
(
simple_tensors
),
separate_last
=
'and '
)
}
contained simple tensors."
f
"
{
sequence_to_str
(
sorted
(
simple_tensors
),
separate_last
=
'and '
)
}
contained simple tensors, "
f
"but didn't find any (encoded) image or video."
)
)
@
parametrize_dataset_mocks
(
DATASET_MOCKS
)
@
parametrize_dataset_mocks
(
DATASET_MOCKS
)
...
...
torchvision/prototype/datapoints/_datapoint.py
View file @
7c9878a4
...
@@ -29,26 +29,9 @@ class Datapoint(torch.Tensor):
...
@@ -29,26 +29,9 @@ class Datapoint(torch.Tensor):
requires_grad
=
data
.
requires_grad
if
isinstance
(
data
,
torch
.
Tensor
)
else
False
requires_grad
=
data
.
requires_grad
if
isinstance
(
data
,
torch
.
Tensor
)
else
False
return
torch
.
as_tensor
(
data
,
dtype
=
dtype
,
device
=
device
).
requires_grad_
(
requires_grad
)
return
torch
.
as_tensor
(
data
,
dtype
=
dtype
,
device
=
device
).
requires_grad_
(
requires_grad
)
# FIXME: this is just here for BC with the prototype datasets. Some datasets use the Datapoint directly to have a
# a no-op input for the prototype transforms. For this use case, we can't use plain tensors, since they will be
# interpreted as images. We should decide if we want a public no-op datapoint like `GenericDatapoint` or make this
# one public again.
def
__new__
(
cls
,
data
:
Any
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
device
:
Optional
[
Union
[
torch
.
device
,
str
,
int
]]
=
None
,
requires_grad
:
Optional
[
bool
]
=
None
,
)
->
Datapoint
:
tensor
=
cls
.
_to_tensor
(
data
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
return
tensor
.
as_subclass
(
Datapoint
)
@
classmethod
@
classmethod
def
wrap_like
(
cls
:
Type
[
D
],
other
:
D
,
tensor
:
torch
.
Tensor
)
->
D
:
def
wrap_like
(
cls
:
Type
[
D
],
other
:
D
,
tensor
:
torch
.
Tensor
)
->
D
:
# FIXME: this is just here for BC with the prototype datasets. See __new__ for details. If that is resolved,
raise
NotImplementedError
# this method should be made abstract
# raise NotImplementedError
return
tensor
.
as_subclass
(
cls
)
_NO_WRAPPING_EXCEPTIONS
=
{
_NO_WRAPPING_EXCEPTIONS
=
{
torch
.
Tensor
.
clone
:
lambda
cls
,
input
,
output
:
cls
.
wrap_like
(
input
,
output
),
torch
.
Tensor
.
clone
:
lambda
cls
,
input
,
output
:
cls
.
wrap_like
(
input
,
output
),
...
...
torchvision/prototype/datasets/_builtin/caltech.py
View file @
7c9878a4
...
@@ -3,9 +3,10 @@ import re
...
@@ -3,9 +3,10 @@ import re
from
typing
import
Any
,
BinaryIO
,
Dict
,
List
,
Tuple
,
Union
from
typing
import
Any
,
BinaryIO
,
Dict
,
List
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
from
torchdata.datapipes.iter
import
Filter
,
IterDataPipe
,
IterKeyZipper
,
Mapper
from
torchdata.datapipes.iter
import
Filter
,
IterDataPipe
,
IterKeyZipper
,
Mapper
from
torchvision.prototype.datapoints
import
BoundingBox
,
Label
from
torchvision.prototype.datapoints
import
BoundingBox
,
Label
from
torchvision.prototype.datapoints._datapoint
import
Datapoint
from
torchvision.prototype.datasets.utils
import
Dataset
,
EncodedImage
,
GDriveResource
,
OnlineResource
from
torchvision.prototype.datasets.utils
import
Dataset
,
EncodedImage
,
GDriveResource
,
OnlineResource
from
torchvision.prototype.datasets.utils._internal
import
(
from
torchvision.prototype.datasets.utils._internal
import
(
hint_sharding
,
hint_sharding
,
...
@@ -115,7 +116,7 @@ class Caltech101(Dataset):
...
@@ -115,7 +116,7 @@ class Caltech101(Dataset):
format
=
"xyxy"
,
format
=
"xyxy"
,
spatial_size
=
image
.
spatial_size
,
spatial_size
=
image
.
spatial_size
,
),
),
contour
=
Datapoint
(
ann
[
"obj_contour"
].
T
),
contour
=
torch
.
as_tensor
(
ann
[
"obj_contour"
].
T
),
)
)
def
_datapipe
(
self
,
resource_dps
:
List
[
IterDataPipe
])
->
IterDataPipe
[
Dict
[
str
,
Any
]]:
def
_datapipe
(
self
,
resource_dps
:
List
[
IterDataPipe
])
->
IterDataPipe
[
Dict
[
str
,
Any
]]:
...
...
torchvision/prototype/datasets/_builtin/celeba.py
View file @
7c9878a4
...
@@ -2,9 +2,9 @@ import csv
...
@@ -2,9 +2,9 @@ import csv
import
pathlib
import
pathlib
from
typing
import
Any
,
BinaryIO
,
Dict
,
Iterator
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
from
typing
import
Any
,
BinaryIO
,
Dict
,
Iterator
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
import
torch
from
torchdata.datapipes.iter
import
Filter
,
IterDataPipe
,
IterKeyZipper
,
Mapper
,
Zipper
from
torchdata.datapipes.iter
import
Filter
,
IterDataPipe
,
IterKeyZipper
,
Mapper
,
Zipper
from
torchvision.prototype.datapoints
import
BoundingBox
,
Label
from
torchvision.prototype.datapoints
import
BoundingBox
,
Label
from
torchvision.prototype.datapoints._datapoint
import
Datapoint
from
torchvision.prototype.datasets.utils
import
Dataset
,
EncodedImage
,
GDriveResource
,
OnlineResource
from
torchvision.prototype.datasets.utils
import
Dataset
,
EncodedImage
,
GDriveResource
,
OnlineResource
from
torchvision.prototype.datasets.utils._internal
import
(
from
torchvision.prototype.datasets.utils._internal
import
(
getitem
,
getitem
,
...
@@ -149,7 +149,7 @@ class CelebA(Dataset):
...
@@ -149,7 +149,7 @@ class CelebA(Dataset):
spatial_size
=
image
.
spatial_size
,
spatial_size
=
image
.
spatial_size
,
),
),
landmarks
=
{
landmarks
=
{
landmark
:
Datapoint
((
int
(
landmarks
[
f
"
{
landmark
}
_x"
]),
int
(
landmarks
[
f
"
{
landmark
}
_y"
])))
landmark
:
torch
.
tensor
((
int
(
landmarks
[
f
"
{
landmark
}
_x"
]),
int
(
landmarks
[
f
"
{
landmark
}
_y"
])))
for
landmark
in
{
key
[:
-
2
]
for
key
in
landmarks
.
keys
()}
for
landmark
in
{
key
[:
-
2
]
for
key
in
landmarks
.
keys
()}
},
},
)
)
...
...
torchvision/prototype/datasets/_builtin/coco.py
View file @
7c9878a4
...
@@ -15,7 +15,6 @@ from torchdata.datapipes.iter import (
...
@@ -15,7 +15,6 @@ from torchdata.datapipes.iter import (
UnBatcher
,
UnBatcher
,
)
)
from
torchvision.prototype.datapoints
import
BoundingBox
,
Label
,
Mask
from
torchvision.prototype.datapoints
import
BoundingBox
,
Label
,
Mask
from
torchvision.prototype.datapoints._datapoint
import
Datapoint
from
torchvision.prototype.datasets.utils
import
Dataset
,
EncodedImage
,
HttpResource
,
OnlineResource
from
torchvision.prototype.datasets.utils
import
Dataset
,
EncodedImage
,
HttpResource
,
OnlineResource
from
torchvision.prototype.datasets.utils._internal
import
(
from
torchvision.prototype.datasets.utils._internal
import
(
getitem
,
getitem
,
...
@@ -124,8 +123,8 @@ class Coco(Dataset):
...
@@ -124,8 +123,8 @@ class Coco(Dataset):
]
]
)
)
),
),
areas
=
Datapoint
([
ann
[
"area"
]
for
ann
in
anns
]),
areas
=
torch
.
as_tensor
([
ann
[
"area"
]
for
ann
in
anns
]),
crowds
=
Datapoint
([
ann
[
"iscrowd"
]
for
ann
in
anns
],
dtype
=
torch
.
bool
),
crowds
=
torch
.
as_tensor
([
ann
[
"iscrowd"
]
for
ann
in
anns
],
dtype
=
torch
.
bool
),
bounding_boxes
=
BoundingBox
(
bounding_boxes
=
BoundingBox
(
[
ann
[
"bbox"
]
for
ann
in
anns
],
[
ann
[
"bbox"
]
for
ann
in
anns
],
format
=
"xywh"
,
format
=
"xywh"
,
...
...
torchvision/prototype/datasets/_builtin/cub200.py
View file @
7c9878a4
...
@@ -3,6 +3,7 @@ import functools
...
@@ -3,6 +3,7 @@ import functools
import
pathlib
import
pathlib
from
typing
import
Any
,
BinaryIO
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
BinaryIO
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
from
torchdata.datapipes.iter
import
(
from
torchdata.datapipes.iter
import
(
CSVDictParser
,
CSVDictParser
,
CSVParser
,
CSVParser
,
...
@@ -15,7 +16,6 @@ from torchdata.datapipes.iter import (
...
@@ -15,7 +16,6 @@ from torchdata.datapipes.iter import (
)
)
from
torchdata.datapipes.map
import
IterToMapConverter
from
torchdata.datapipes.map
import
IterToMapConverter
from
torchvision.prototype.datapoints
import
BoundingBox
,
Label
from
torchvision.prototype.datapoints
import
BoundingBox
,
Label
from
torchvision.prototype.datapoints._datapoint
import
Datapoint
from
torchvision.prototype.datasets.utils
import
Dataset
,
EncodedImage
,
GDriveResource
,
OnlineResource
from
torchvision.prototype.datasets.utils
import
Dataset
,
EncodedImage
,
GDriveResource
,
OnlineResource
from
torchvision.prototype.datasets.utils._internal
import
(
from
torchvision.prototype.datasets.utils._internal
import
(
getitem
,
getitem
,
...
@@ -162,7 +162,7 @@ class CUB200(Dataset):
...
@@ -162,7 +162,7 @@ class CUB200(Dataset):
format
=
"xyxy"
,
format
=
"xyxy"
,
spatial_size
=
spatial_size
,
spatial_size
=
spatial_size
,
),
),
segmentation
=
Datapoint
(
content
[
"seg"
]),
segmentation
=
torch
.
as_tensor
(
content
[
"seg"
]),
)
)
def
_prepare_sample
(
def
_prepare_sample
(
...
...
torchvision/prototype/datasets/_builtin/sbd.py
View file @
7c9878a4
...
@@ -3,8 +3,8 @@ import re
...
@@ -3,8 +3,8 @@ import re
from
typing
import
Any
,
BinaryIO
,
cast
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
BinaryIO
,
cast
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
from
torchdata.datapipes.iter
import
Demultiplexer
,
Filter
,
IterDataPipe
,
IterKeyZipper
,
LineReader
,
Mapper
from
torchdata.datapipes.iter
import
Demultiplexer
,
Filter
,
IterDataPipe
,
IterKeyZipper
,
LineReader
,
Mapper
from
torchvision.prototype.datapoints._datapoint
import
Datapoint
from
torchvision.prototype.datasets.utils
import
Dataset
,
EncodedImage
,
HttpResource
,
OnlineResource
from
torchvision.prototype.datasets.utils
import
Dataset
,
EncodedImage
,
HttpResource
,
OnlineResource
from
torchvision.prototype.datasets.utils._internal
import
(
from
torchvision.prototype.datasets.utils._internal
import
(
getitem
,
getitem
,
...
@@ -92,8 +92,10 @@ class SBD(Dataset):
...
@@ -92,8 +92,10 @@ class SBD(Dataset):
image
=
EncodedImage
.
from_file
(
image_buffer
),
image
=
EncodedImage
.
from_file
(
image_buffer
),
ann_path
=
ann_path
,
ann_path
=
ann_path
,
# the boundaries are stored in sparse CSC format, which is not supported by PyTorch
# the boundaries are stored in sparse CSC format, which is not supported by PyTorch
boundaries
=
Datapoint
(
np
.
stack
([
raw_boundary
.
toarray
()
for
raw_boundary
in
anns
[
"Boundaries"
].
item
()])),
boundaries
=
torch
.
as_tensor
(
segmentation
=
Datapoint
(
anns
[
"Segmentation"
].
item
()),
np
.
stack
([
raw_boundary
.
toarray
()
for
raw_boundary
in
anns
[
"Boundaries"
].
item
()])
),
segmentation
=
torch
.
as_tensor
(
anns
[
"Segmentation"
].
item
()),
)
)
def
_datapipe
(
self
,
resource_dps
:
List
[
IterDataPipe
])
->
IterDataPipe
[
Dict
[
str
,
Any
]]:
def
_datapipe
(
self
,
resource_dps
:
List
[
IterDataPipe
])
->
IterDataPipe
[
Dict
[
str
,
Any
]]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment