Commit 10e658cd (unverified)
Authored Dec 14, 2021 by Philip Meier, committed by GitHub on Dec 14, 2021

fix caltech and imagenet prototype datasets (#5032)

Parent: d98cccb0
Showing 3 changed files with 29 additions and 29 deletions (+29 -29):

    test/test_prototype_builtin_datasets.py  +13 -5
    torchvision/prototype/datasets/_builtin/caltech.py  +2 -2
    torchvision/prototype/datasets/_builtin/imagenet.py  +14 -22
test/test_prototype_builtin_datasets.py

@@ -2,9 +2,10 @@ import io
 import builtin_dataset_mocks
 import pytest
+import torch
 from torch.utils.data.graph import traverse
 from torchdata.datapipes.iter import IterDataPipe
-from torchvision.prototype import datasets, features
+from torchvision.prototype import datasets, transforms
 from torchvision.prototype.datasets._api import DEFAULT_DECODER
 from torchvision.prototype.utils._internal import sequence_to_str
@@ -88,10 +89,17 @@ class TestCommon:
     )
     @dataset_parametrization(decoder=DEFAULT_DECODER)
-    def test_at_least_one_feature(self, dataset, mock_info):
-        sample = next(iter(dataset))
-        if not any(isinstance(value, features.Feature) for value in sample.values()):
-            raise AssertionError("The sample contained no feature.")
+    def test_no_vanilla_tensors(self, dataset, mock_info):
+        vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor}
+        if vanilla_tensors:
+            raise AssertionError(
+                f"The values of key(s) "
+                f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
+            )
+
+    @dataset_parametrization()
+    def test_transformable(self, dataset, mock_info):
+        next(iter(dataset.map(transforms.Identity())))

     @dataset_parametrization()
     def test_traversable(self, dataset, mock_info):
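Note on the new check: test_no_vanilla_tensors flags values with an exact type() comparison rather than isinstance(), so subclasses of torch.Tensor (such as the prototype Feature wrappers) are not counted as "vanilla". A minimal, self-contained sketch of that distinction, using a stand-in Feature class rather than the real torchvision.prototype.features.Feature:

    import torch


    class Feature(torch.Tensor):
        """Stand-in for torchvision.prototype.features.Feature (a Tensor subclass)."""


    feature = torch.tensor([1.0, 2.0]).as_subclass(Feature)
    sample = {
        "image": feature,          # Tensor subclass -> not "vanilla"
        "label": torch.tensor(3),  # plain Tensor    -> flagged
        "path": "some/file.jpg",   # non-tensor      -> ignored
    }

    # The exact-type check the new test relies on: isinstance() would also accept
    # the Feature, because it subclasses torch.Tensor, but type() does not.
    vanilla_tensors = {key for key, value in sample.items() if type(value) is torch.Tensor}
    assert vanilla_tensors == {"label"}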
torchvision/prototype/datasets/_builtin/caltech.py

@@ -21,7 +21,7 @@ from torchvision.prototype.datasets.utils import (
     DatasetType,
 )
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat
-from torchvision.prototype.features import Label, BoundingBox
+from torchvision.prototype.features import Label, BoundingBox, Feature


 class Caltech101(Dataset):
@@ -98,7 +98,7 @@ class Caltech101(Dataset):
         ann = read_mat(ann_buffer)
         bbox = BoundingBox(ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]], format="xyxy")
-        contour = torch.tensor(ann["obj_contour"].T)
+        contour = Feature(ann["obj_contour"].T)

         return dict(
             category=category,
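For context, the unchanged bbox line in the hunk above rearranges Caltech-101's box_coord annotation into "xyxy" order with the fancy index [[2, 0, 3, 1]]. A small sketch of what that permutation does, assuming the annotation stores (top, bottom, left, right); the coordinate values are made up for illustration:

    import numpy as np

    # hypothetical annotation, assumed to be stored as (top, bottom, left, right)
    box_coord = np.array([[15, 120, 30, 200]])

    # same permutation as in caltech.py: pick indices 2, 0, 3, 1
    xyxy = box_coord.astype(np.int64).squeeze()[[2, 0, 3, 1]]
    print(xyxy)  # [ 30  15 200 120] -> (left, top, right, bottom), i.e. "xyxy"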
torchvision/prototype/datasets/_builtin/imagenet.py

@@ -21,7 +21,7 @@ from torchvision.prototype.datasets.utils._internal import (
     getitem,
     read_mat,
 )
-from torchvision.prototype.features import Label, DEFAULT
+from torchvision.prototype.features import Label
 from torchvision.prototype.utils._internal import FrozenMapping
@@ -30,18 +30,6 @@ class ImageNetResource(ManualDownloadResource):
         super().__init__("Register on https://image-net.org/ and follow the instructions there.", **kwargs)


-class ImageNetLabel(Label):
-    wnid: Optional[str]
-
-    @classmethod
-    def _parse_meta_data(
-        cls,
-        category: Optional[str] = DEFAULT,  # type: ignore[assignment]
-        wnid: Optional[str] = DEFAULT,  # type: ignore[assignment]
-    ) -> Dict[str, Tuple[Any, Any]]:
-        return dict(category=(category, None), wnid=(wnid, None))
-
-
 class ImageNet(Dataset):
     def _make_info(self) -> DatasetInfo:
         name = "imagenet"
@@ -97,12 +85,12 @@ class ImageNet(Dataset):
     _TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG")

-    def _collate_train_data(self, data: Tuple[str, io.IOBase]) -> Tuple[ImageNetLabel, Tuple[str, io.IOBase]]:
+    def _collate_train_data(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[Label, str, str], Tuple[str, io.IOBase]]:
         path = pathlib.Path(data[0])
         wnid = self._TRAIN_IMAGE_NAME_PATTERN.match(path.name).group("wnid")  # type: ignore[union-attr]
         category = self.wnid_to_category[wnid]
-        label = ImageNetLabel(self.categories.index(category), category=category, wnid=wnid)
-        return label, data
+        label_data = (Label(self.categories.index(category)), category, wnid)
+        return label_data, data

     _VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")
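The train collate function above extracts the WordNet id from the file name via _TRAIN_IMAGE_NAME_PATTERN. A quick sketch of that pattern against a typical train file name (the name below is only illustrative):

    import re

    _TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG")

    match = _TRAIN_IMAGE_NAME_PATTERN.match("n01440764_10026.JPEG")
    assert match is not None
    print(match.group("wnid"))  # n01440764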
@@ -112,28 +100,32 @@ class ImageNet(Dataset):
     def _collate_val_data(
         self, data: Tuple[Tuple[int, int], Tuple[str, io.IOBase]]
-    ) -> Tuple[ImageNetLabel, Tuple[str, io.IOBase]]:
+    ) -> Tuple[Tuple[Label, str, str], Tuple[str, io.IOBase]]:
         label_data, image_data = data
         _, label = label_data
         category = self.categories[label]
         wnid = self.category_to_wnid[category]
-        return ImageNetLabel(label, category=category, wnid=wnid), image_data
+        return (Label(label), category, wnid), image_data

     def _collate_test_data(self, data: Tuple[str, io.IOBase]) -> Tuple[None, Tuple[str, io.IOBase]]:
         return None, data

     def _collate_and_decode_sample(
         self,
-        data: Tuple[Optional[ImageNetLabel], Tuple[str, io.IOBase]],
+        data: Tuple[Optional[Tuple[Label, str, str]], Tuple[str, io.IOBase]],
         *,
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> Dict[str, Any]:
-        label, (path, buffer) = data
-        return dict(
+        label_data, (path, buffer) = data
+
+        sample = dict(
             path=path,
             image=decoder(buffer) if decoder else buffer,
-            label=label,
         )
+        if label_data:
+            sample.update(dict(zip(("label", "category", "wnid"), label_data)))
+        return sample

     def _make_datapipe(
         self,
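The reworked _collate_and_decode_sample folds the optional (Label, category, wnid) tuple into the sample dict via dict(zip(...)), so the test split (whose label data is None) simply carries no label keys. A standalone sketch of that pattern with placeholder values:

    def build_sample(label_data, path, image):
        # mirrors the update logic above; label_data is either None or (label, category, wnid)
        sample = dict(path=path, image=image)
        if label_data:
            sample.update(dict(zip(("label", "category", "wnid"), label_data)))
        return sample


    # train/val: label data available -> the sample gains "label", "category" and "wnid"
    print(build_sample((0, "tench", "n01440764"), "n01440764_10026.JPEG", image=b"..."))
    # test: no label data -> only "path" and "image"
    print(build_sample(None, "ILSVRC2012_test_00000001.JPEG", image=b"..."))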