Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
10e658cd
"src/vscode:/vscode.git/clone" did not exist on "421763fb6791261e004a2833263425cddc6bf783"
Unverified
Commit
10e658cd
authored
Dec 14, 2021
by
Philip Meier
Committed by
GitHub
Dec 14, 2021
Browse files
fix caltech and imagenet prototype datasets (#5032)
parent
d98cccb0
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
29 deletions
+29
-29
test/test_prototype_builtin_datasets.py
test/test_prototype_builtin_datasets.py
+13
-5
torchvision/prototype/datasets/_builtin/caltech.py
torchvision/prototype/datasets/_builtin/caltech.py
+2
-2
torchvision/prototype/datasets/_builtin/imagenet.py
torchvision/prototype/datasets/_builtin/imagenet.py
+14
-22
No files found.
test/test_prototype_builtin_datasets.py
View file @
10e658cd
...
@@ -2,9 +2,10 @@ import io
...
@@ -2,9 +2,10 @@ import io
import
builtin_dataset_mocks
import
builtin_dataset_mocks
import
pytest
import
pytest
import
torch
from
torch.utils.data.graph
import
traverse
from
torch.utils.data.graph
import
traverse
from
torchdata.datapipes.iter
import
IterDataPipe
from
torchdata.datapipes.iter
import
IterDataPipe
from
torchvision.prototype
import
datasets
,
feature
s
from
torchvision.prototype
import
datasets
,
transform
s
from
torchvision.prototype.datasets._api
import
DEFAULT_DECODER
from
torchvision.prototype.datasets._api
import
DEFAULT_DECODER
from
torchvision.prototype.utils._internal
import
sequence_to_str
from
torchvision.prototype.utils._internal
import
sequence_to_str
...
@@ -88,10 +89,17 @@ class TestCommon:
...
@@ -88,10 +89,17 @@ class TestCommon:
)
)
@
dataset_parametrization
(
decoder
=
DEFAULT_DECODER
)
@
dataset_parametrization
(
decoder
=
DEFAULT_DECODER
)
def
test_at_least_one_feature
(
self
,
dataset
,
mock_info
):
def
test_no_vanilla_tensors
(
self
,
dataset
,
mock_info
):
sample
=
next
(
iter
(
dataset
))
vanilla_tensors
=
{
key
for
key
,
value
in
next
(
iter
(
dataset
)).
items
()
if
type
(
value
)
is
torch
.
Tensor
}
if
not
any
(
isinstance
(
value
,
features
.
Feature
)
for
value
in
sample
.
values
()):
if
vanilla_tensors
:
raise
AssertionError
(
"The sample contained no feature."
)
raise
AssertionError
(
f
"The values of key(s) "
f
"
{
sequence_to_str
(
sorted
(
vanilla_tensors
),
separate_last
=
'and '
)
}
contained vanilla tensors."
)
@
dataset_parametrization
()
def
test_transformable
(
self
,
dataset
,
mock_info
):
next
(
iter
(
dataset
.
map
(
transforms
.
Identity
())))
@
dataset_parametrization
()
@
dataset_parametrization
()
def
test_traversable
(
self
,
dataset
,
mock_info
):
def
test_traversable
(
self
,
dataset
,
mock_info
):
...
...
torchvision/prototype/datasets/_builtin/caltech.py
View file @
10e658cd
...
@@ -21,7 +21,7 @@ from torchvision.prototype.datasets.utils import (
...
@@ -21,7 +21,7 @@ from torchvision.prototype.datasets.utils import (
DatasetType
,
DatasetType
,
)
)
from
torchvision.prototype.datasets.utils._internal
import
INFINITE_BUFFER_SIZE
,
read_mat
from
torchvision.prototype.datasets.utils._internal
import
INFINITE_BUFFER_SIZE
,
read_mat
from
torchvision.prototype.features
import
Label
,
BoundingBox
from
torchvision.prototype.features
import
Label
,
BoundingBox
,
Feature
class
Caltech101
(
Dataset
):
class
Caltech101
(
Dataset
):
...
@@ -98,7 +98,7 @@ class Caltech101(Dataset):
...
@@ -98,7 +98,7 @@ class Caltech101(Dataset):
ann
=
read_mat
(
ann_buffer
)
ann
=
read_mat
(
ann_buffer
)
bbox
=
BoundingBox
(
ann
[
"box_coord"
].
astype
(
np
.
int64
).
squeeze
()[[
2
,
0
,
3
,
1
]],
format
=
"xyxy"
)
bbox
=
BoundingBox
(
ann
[
"box_coord"
].
astype
(
np
.
int64
).
squeeze
()[[
2
,
0
,
3
,
1
]],
format
=
"xyxy"
)
contour
=
torch
.
tensor
(
ann
[
"obj_contour"
].
T
)
contour
=
Feature
(
ann
[
"obj_contour"
].
T
)
return
dict
(
return
dict
(
category
=
category
,
category
=
category
,
...
...
torchvision/prototype/datasets/_builtin/imagenet.py
View file @
10e658cd
...
@@ -21,7 +21,7 @@ from torchvision.prototype.datasets.utils._internal import (
...
@@ -21,7 +21,7 @@ from torchvision.prototype.datasets.utils._internal import (
getitem
,
getitem
,
read_mat
,
read_mat
,
)
)
from
torchvision.prototype.features
import
Label
,
DEFAULT
from
torchvision.prototype.features
import
Label
from
torchvision.prototype.utils._internal
import
FrozenMapping
from
torchvision.prototype.utils._internal
import
FrozenMapping
...
@@ -30,18 +30,6 @@ class ImageNetResource(ManualDownloadResource):
...
@@ -30,18 +30,6 @@ class ImageNetResource(ManualDownloadResource):
super
().
__init__
(
"Register on https://image-net.org/ and follow the instructions there."
,
**
kwargs
)
super
().
__init__
(
"Register on https://image-net.org/ and follow the instructions there."
,
**
kwargs
)
class
ImageNetLabel
(
Label
):
wnid
:
Optional
[
str
]
@
classmethod
def
_parse_meta_data
(
cls
,
category
:
Optional
[
str
]
=
DEFAULT
,
# type: ignore[assignment]
wnid
:
Optional
[
str
]
=
DEFAULT
,
# type: ignore[assignment]
)
->
Dict
[
str
,
Tuple
[
Any
,
Any
]]:
return
dict
(
category
=
(
category
,
None
),
wnid
=
(
wnid
,
None
))
class
ImageNet
(
Dataset
):
class
ImageNet
(
Dataset
):
def
_make_info
(
self
)
->
DatasetInfo
:
def
_make_info
(
self
)
->
DatasetInfo
:
name
=
"imagenet"
name
=
"imagenet"
...
@@ -97,12 +85,12 @@ class ImageNet(Dataset):
...
@@ -97,12 +85,12 @@ class ImageNet(Dataset):
_TRAIN_IMAGE_NAME_PATTERN
=
re
.
compile
(
r
"(?P<wnid>n\d{8})_\d+[.]JPEG"
)
_TRAIN_IMAGE_NAME_PATTERN
=
re
.
compile
(
r
"(?P<wnid>n\d{8})_\d+[.]JPEG"
)
def
_collate_train_data
(
self
,
data
:
Tuple
[
str
,
io
.
IOBase
])
->
Tuple
[
ImageNetLabel
,
Tuple
[
str
,
io
.
IOBase
]]:
def
_collate_train_data
(
self
,
data
:
Tuple
[
str
,
io
.
IOBase
])
->
Tuple
[
Tuple
[
Label
,
str
,
str
]
,
Tuple
[
str
,
io
.
IOBase
]]:
path
=
pathlib
.
Path
(
data
[
0
])
path
=
pathlib
.
Path
(
data
[
0
])
wnid
=
self
.
_TRAIN_IMAGE_NAME_PATTERN
.
match
(
path
.
name
).
group
(
"wnid"
)
# type: ignore[union-attr]
wnid
=
self
.
_TRAIN_IMAGE_NAME_PATTERN
.
match
(
path
.
name
).
group
(
"wnid"
)
# type: ignore[union-attr]
category
=
self
.
wnid_to_category
[
wnid
]
category
=
self
.
wnid_to_category
[
wnid
]
label
=
ImageNet
Label
(
self
.
categories
.
index
(
category
),
category
=
category
,
wnid
=
wnid
)
label
_data
=
(
Label
(
self
.
categories
.
index
(
category
)
)
,
category
,
wnid
)
return
label
,
data
return
label
_data
,
data
_VAL_TEST_IMAGE_NAME_PATTERN
=
re
.
compile
(
r
"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG"
)
_VAL_TEST_IMAGE_NAME_PATTERN
=
re
.
compile
(
r
"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG"
)
...
@@ -112,28 +100,32 @@ class ImageNet(Dataset):
...
@@ -112,28 +100,32 @@ class ImageNet(Dataset):
def
_collate_val_data
(
def
_collate_val_data
(
self
,
data
:
Tuple
[
Tuple
[
int
,
int
],
Tuple
[
str
,
io
.
IOBase
]]
self
,
data
:
Tuple
[
Tuple
[
int
,
int
],
Tuple
[
str
,
io
.
IOBase
]]
)
->
Tuple
[
ImageNetLabel
,
Tuple
[
str
,
io
.
IOBase
]]:
)
->
Tuple
[
Tuple
[
Label
,
str
,
str
]
,
Tuple
[
str
,
io
.
IOBase
]]:
label_data
,
image_data
=
data
label_data
,
image_data
=
data
_
,
label
=
label_data
_
,
label
=
label_data
category
=
self
.
categories
[
label
]
category
=
self
.
categories
[
label
]
wnid
=
self
.
category_to_wnid
[
category
]
wnid
=
self
.
category_to_wnid
[
category
]
return
ImageNet
Label
(
label
,
category
=
category
,
wnid
=
wnid
),
image_data
return
(
Label
(
label
)
,
category
,
wnid
),
image_data
def
_collate_test_data
(
self
,
data
:
Tuple
[
str
,
io
.
IOBase
])
->
Tuple
[
None
,
Tuple
[
str
,
io
.
IOBase
]]:
def
_collate_test_data
(
self
,
data
:
Tuple
[
str
,
io
.
IOBase
])
->
Tuple
[
None
,
Tuple
[
str
,
io
.
IOBase
]]:
return
None
,
data
return
None
,
data
def
_collate_and_decode_sample
(
def
_collate_and_decode_sample
(
self
,
self
,
data
:
Tuple
[
Optional
[
ImageNetLabel
],
Tuple
[
str
,
io
.
IOBase
]],
data
:
Tuple
[
Optional
[
Tuple
[
Label
,
str
,
str
]
],
Tuple
[
str
,
io
.
IOBase
]],
*
,
*
,
decoder
:
Optional
[
Callable
[[
io
.
IOBase
],
torch
.
Tensor
]],
decoder
:
Optional
[
Callable
[[
io
.
IOBase
],
torch
.
Tensor
]],
)
->
Dict
[
str
,
Any
]:
)
->
Dict
[
str
,
Any
]:
label
,
(
path
,
buffer
)
=
data
label_data
,
(
path
,
buffer
)
=
data
return
dict
(
sample
=
dict
(
path
=
path
,
path
=
path
,
image
=
decoder
(
buffer
)
if
decoder
else
buffer
,
image
=
decoder
(
buffer
)
if
decoder
else
buffer
,
label
=
label
,
)
)
if
label_data
:
sample
.
update
(
dict
(
zip
((
"label"
,
"category"
,
"wnid"
),
label_data
)))
return
sample
def
_make_datapipe
(
def
_make_datapipe
(
self
,
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment