Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
874581cb
Unverified
Commit
874581cb
authored
Dec 01, 2021
by
Philip Meier
Committed by
GitHub
Dec 01, 2021
Browse files
remove vanilla tensors from prototype datasets samples (#5018)
parent
0aa3717d
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
27 additions
and
32 deletions
+27
-32
torchvision/prototype/datasets/_builtin/coco.py
torchvision/prototype/datasets/_builtin/coco.py
+19
-30
torchvision/prototype/features/_feature.py
torchvision/prototype/features/_feature.py
+1
-1
torchvision/prototype/transforms/_transform.py
torchvision/prototype/transforms/_transform.py
+7
-1
No files found.
torchvision/prototype/datasets/_builtin/coco.py
View file @
874581cb
...
@@ -32,23 +32,10 @@ from torchvision.prototype.datasets.utils._internal import (
...
@@ -32,23 +32,10 @@ from torchvision.prototype.datasets.utils._internal import (
getitem
,
getitem
,
path_accessor
,
path_accessor
,
)
)
from
torchvision.prototype.features
import
BoundingBox
,
Label
from
torchvision.prototype.features
import
BoundingBox
,
Label
,
Feature
from
torchvision.prototype.features._feature
import
DEFAULT
from
torchvision.prototype.utils._internal
import
FrozenMapping
from
torchvision.prototype.utils._internal
import
FrozenMapping
class
CocoLabel
(
Label
):
super_category
:
Optional
[
str
]
@
classmethod
def
_parse_meta_data
(
cls
,
category
:
Optional
[
str
]
=
DEFAULT
,
# type: ignore[assignment]
super_category
:
Optional
[
str
]
=
DEFAULT
,
# type: ignore[assignment]
)
->
Dict
[
str
,
Tuple
[
Any
,
Any
]]:
return
dict
(
category
=
(
category
,
None
),
super_category
=
(
super_category
,
None
))
class
Coco
(
Dataset
):
class
Coco
(
Dataset
):
def
_make_info
(
self
)
->
DatasetInfo
:
def
_make_info
(
self
)
->
DatasetInfo
:
name
=
"coco"
name
=
"coco"
...
@@ -111,27 +98,24 @@ class Coco(Dataset):
...
@@ -111,27 +98,24 @@ class Coco(Dataset):
categories
=
[
self
.
info
.
categories
[
label
]
for
label
in
labels
]
categories
=
[
self
.
info
.
categories
[
label
]
for
label
in
labels
]
return
dict
(
return
dict
(
# TODO: create a segmentation feature
# TODO: create a segmentation feature
segmentations
=
torch
.
stack
(
segmentations
=
Feature
(
torch
.
stack
(
[
[
self
.
_segmentation_to_mask
(
ann
[
"segmentation"
],
is_crowd
=
ann
[
"iscrowd"
],
image_size
=
image_size
)
self
.
_segmentation_to_mask
(
ann
[
"segmentation"
],
is_crowd
=
ann
[
"iscrowd"
],
image_size
=
image_size
)
for
ann
in
anns
for
ann
in
anns
]
]
)
),
),
areas
=
torch
.
tensor
([
ann
[
"area"
]
for
ann
in
anns
]),
areas
=
Feature
([
ann
[
"area"
]
for
ann
in
anns
]),
crowds
=
torch
.
tensor
([
ann
[
"iscrowd"
]
for
ann
in
anns
],
dtype
=
torch
.
bool
),
crowds
=
Feature
([
ann
[
"iscrowd"
]
for
ann
in
anns
],
dtype
=
torch
.
bool
),
bounding_boxes
=
BoundingBox
(
bounding_boxes
=
BoundingBox
(
[
ann
[
"bbox"
]
for
ann
in
anns
],
[
ann
[
"bbox"
]
for
ann
in
anns
],
format
=
"xywh"
,
format
=
"xywh"
,
image_size
=
image_size
,
image_size
=
image_size
,
),
),
labels
=
[
labels
=
Label
(
labels
),
CocoLabel
(
categories
=
categories
,
label
,
super_categories
=
[
self
.
info
.
extra
.
category_to_super_category
[
category
]
for
category
in
categories
],
category
=
category
,
super_category
=
self
.
info
.
extra
.
category_to_super_category
[
category
],
)
for
label
,
category
in
zip
(
labels
,
categories
)
],
ann_ids
=
[
ann
[
"id"
]
for
ann
in
anns
],
ann_ids
=
[
ann
[
"id"
]
for
ann
in
anns
],
)
)
...
@@ -141,7 +125,12 @@ class Coco(Dataset):
...
@@ -141,7 +125,12 @@ class Coco(Dataset):
ann_ids
=
[
ann
[
"id"
]
for
ann
in
anns
],
ann_ids
=
[
ann
[
"id"
]
for
ann
in
anns
],
)
)
_ANN_DECODERS
=
OrderedDict
([(
"instances"
,
_decode_instances_anns
),
(
"captions"
,
_decode_captions_ann
)])
_ANN_DECODERS
=
OrderedDict
(
[
(
"instances"
,
_decode_instances_anns
),
(
"captions"
,
_decode_captions_ann
),
]
)
_META_FILE_PATTERN
=
re
.
compile
(
_META_FILE_PATTERN
=
re
.
compile
(
fr
"(?P<annotations>(
{
'|'
.
join
(
_ANN_DECODERS
.
keys
())
}
))_(?P<split>[a-zA-Z]+)(?P<year>\d+)[.]json"
fr
"(?P<annotations>(
{
'|'
.
join
(
_ANN_DECODERS
.
keys
())
}
))_(?P<split>[a-zA-Z]+)(?P<year>\d+)[.]json"
...
...
torchvision/prototype/features/_feature.py
View file @
874581cb
...
@@ -12,7 +12,7 @@ DEFAULT = object()
...
@@ -12,7 +12,7 @@ DEFAULT = object()
class
Feature
(
torch
.
Tensor
):
class
Feature
(
torch
.
Tensor
):
_META_ATTRS
:
Set
[
str
]
_META_ATTRS
:
Set
[
str
]
=
set
()
_meta_data
:
Dict
[
str
,
Any
]
_meta_data
:
Dict
[
str
,
Any
]
def
__init_subclass__
(
cls
):
def
__init_subclass__
(
cls
):
...
...
torchvision/prototype/transforms/_transform.py
View file @
874581cb
...
@@ -360,7 +360,13 @@ class Transform(nn.Module):
...
@@ -360,7 +360,13 @@ class Transform(nn.Module):
else
:
else
:
feature_type
=
type
(
sample
)
feature_type
=
type
(
sample
)
if
not
self
.
supports
(
feature_type
):
if
not
self
.
supports
(
feature_type
):
if
not
issubclass
(
feature_type
,
features
.
Feature
)
or
feature_type
in
self
.
NO_OP_FEATURE_TYPES
:
if
(
not
issubclass
(
feature_type
,
features
.
Feature
)
# issubclass is not a strict check, but also allows the type checked against. Thus, we need to
# check it separately
or
feature_type
is
features
.
Feature
or
feature_type
in
self
.
NO_OP_FEATURE_TYPES
):
return
sample
return
sample
raise
TypeError
(
raise
TypeError
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment