Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
e836b3d8
Unverified
Commit
e836b3d8
authored
Mar 03, 2022
by
Philip Meier
Committed by
GitHub
Mar 03, 2022
Browse files
simplify Feature implementation (#5539)
* simplify Feature implementation * fix mypy
parent
97385df0
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
111 additions
and
86 deletions
+111
-86
torchvision/_utils.py
torchvision/_utils.py
+4
-1
torchvision/prototype/features/_bounding_box.py
torchvision/prototype/features/_bounding_box.py
+24
-4
torchvision/prototype/features/_encoded.py
torchvision/prototype/features/_encoded.py
+15
-7
torchvision/prototype/features/_feature.py
torchvision/prototype/features/_feature.py
+16
-48
torchvision/prototype/features/_image.py
torchvision/prototype/features/_image.py
+18
-12
torchvision/prototype/features/_label.py
torchvision/prototype/features/_label.py
+33
-13
torchvision/prototype/transforms/_geometry.py
torchvision/prototype/transforms/_geometry.py
+1
-1
No files found.
torchvision/_utils.py
View file @
e836b3d8
import
enum
from
typing
import
TypeVar
,
Type
T
=
TypeVar
(
"T"
,
bound
=
enum
.
Enum
)
class
StrEnumMeta
(
enum
.
EnumMeta
):
auto
=
enum
.
auto
def
from_str
(
self
,
member
:
str
)
:
def
from_str
(
self
:
Type
[
T
]
,
member
:
str
)
->
T
:
# type: ignore[misc]
try
:
return
self
[
member
]
except
KeyError
:
...
...
torchvision/prototype/features/_bounding_box.py
View file @
e836b3d8
...
...
@@ -22,20 +22,40 @@ class BoundingBox(_Feature):
cls
,
data
:
Any
,
*
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
device
:
Optional
[
torch
.
device
]
=
None
,
format
:
Union
[
BoundingBoxFormat
,
str
],
image_size
:
Tuple
[
int
,
int
],
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
device
:
Optional
[
Union
[
torch
.
device
,
str
,
int
]]
=
None
,
requires_grad
:
bool
=
False
,
)
->
BoundingBox
:
bounding_box
=
super
().
__new__
(
cls
,
data
,
dtype
=
dtype
,
device
=
device
)
bounding_box
=
super
().
__new__
(
cls
,
data
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
if
isinstance
(
format
,
str
):
format
=
BoundingBoxFormat
.
from_str
(
format
.
upper
())
bounding_box
.
format
=
format
bounding_box
.
_metadata
.
update
(
dict
(
format
=
format
,
image_size
=
image_size
))
bounding_box
.
image_size
=
image_size
return
bounding_box
@
classmethod
def
new_like
(
cls
,
other
:
BoundingBox
,
data
:
Any
,
*
,
format
:
Optional
[
Union
[
BoundingBoxFormat
,
str
]]
=
None
,
image_size
:
Optional
[
Tuple
[
int
,
int
]]
=
None
,
**
kwargs
:
Any
,
)
->
BoundingBox
:
return
super
().
new_like
(
other
,
data
,
format
=
format
if
format
is
not
None
else
other
.
format
,
image_size
=
image_size
if
image_size
is
not
None
else
other
.
image_size
,
**
kwargs
,
)
def
to_format
(
self
,
format
:
Union
[
str
,
BoundingBoxFormat
])
->
BoundingBox
:
# TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
# promote this out of the prototype state
...
...
torchvision/prototype/features/_encoded.py
View file @
e836b3d8
from
__future__
import
annotations
import
os
import
sys
from
typing
import
BinaryIO
,
Tuple
,
Type
,
TypeVar
,
Union
,
Optional
,
Any
...
...
@@ -13,19 +15,25 @@ D = TypeVar("D", bound="EncodedData")
class
EncodedData
(
_Feature
):
@
classmethod
def
_to_tensor
(
cls
,
data
:
Any
,
*
,
dtype
:
Optional
[
torch
.
dtype
],
device
:
Optional
[
torch
.
device
])
->
torch
.
Tensor
:
def
__new__
(
cls
,
data
:
Any
,
*
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
device
:
Optional
[
Union
[
torch
.
device
,
str
,
int
]]
=
None
,
requires_grad
:
bool
=
False
,
)
->
EncodedData
:
# TODO: warn / bail out if we encounter a tensor with shape other than (N,) or with dtype other than uint8?
return
super
().
_
to_tensor
(
data
,
dtype
=
dtype
,
device
=
device
)
return
super
().
_
_new__
(
cls
,
data
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
@
classmethod
def
from_file
(
cls
:
Type
[
D
],
file
:
BinaryIO
)
->
D
:
return
cls
(
fromfile
(
file
,
dtype
=
torch
.
uint8
,
byte_order
=
sys
.
byteorder
))
def
from_file
(
cls
:
Type
[
D
],
file
:
BinaryIO
,
**
kwargs
:
Any
)
->
D
:
return
cls
(
fromfile
(
file
,
dtype
=
torch
.
uint8
,
byte_order
=
sys
.
byteorder
)
,
**
kwargs
)
@
classmethod
def
from_path
(
cls
:
Type
[
D
],
path
:
Union
[
str
,
os
.
PathLike
])
->
D
:
def
from_path
(
cls
:
Type
[
D
],
path
:
Union
[
str
,
os
.
PathLike
]
,
**
kwargs
:
Any
)
->
D
:
with
open
(
path
,
"rb"
)
as
file
:
return
cls
.
from_file
(
file
)
return
cls
.
from_file
(
file
,
**
kwargs
)
class
EncodedImage
(
EncodedData
):
...
...
torchvision/prototype/features/_feature.py
View file @
e836b3d8
from
typing
import
Any
,
cast
,
Dict
,
Set
,
TypeVar
,
Union
,
Optional
,
Type
,
Callable
,
Tuple
,
Sequence
,
Mapping
from
typing
import
Any
,
cast
,
TypeVar
,
Union
,
Optional
,
Type
,
Callable
,
Tuple
,
Sequence
,
Mapping
import
torch
from
torch._C
import
_TensorBase
,
DisableTorchFunction
...
...
@@ -8,59 +8,22 @@ F = TypeVar("F", bound="_Feature")
class
_Feature
(
torch
.
Tensor
):
_META_ATTRS
:
Set
[
str
]
=
set
()
_metadata
:
Dict
[
str
,
Any
]
def
__init_subclass__
(
cls
)
->
None
:
"""
For convenient copying of metadata, we store it inside a dictionary rather than multiple individual attributes.
By adding the metadata attributes as class annotations on subclasses of :class:`Feature`, this method adds
properties to have the same convenient access as regular attributes.
>>> class Foo(_Feature):
... bar: str
... baz: Optional[str]
>>> foo = Foo()
>>> foo.bar
>>> foo.baz
This has the additional benefit that autocomplete engines and static type checkers are aware of the metadata.
"""
meta_attrs
=
{
attr
for
attr
in
cls
.
__annotations__
.
keys
()
-
cls
.
__dict__
.
keys
()
if
not
attr
.
startswith
(
"_"
)}
for
super_cls
in
cls
.
__mro__
[
1
:]:
if
super_cls
is
_Feature
:
break
meta_attrs
.
update
(
cast
(
Type
[
_Feature
],
super_cls
).
_META_ATTRS
)
cls
.
_META_ATTRS
=
meta_attrs
for
name
in
meta_attrs
:
setattr
(
cls
,
name
,
property
(
cast
(
Callable
[[
F
],
Any
],
lambda
self
,
name
=
name
:
self
.
_metadata
[
name
])))
def
__new__
(
cls
:
Type
[
F
],
data
:
Any
,
*
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
device
:
Optional
[
Union
[
torch
.
device
,
str
]]
=
None
,
device
:
Optional
[
Union
[
torch
.
device
,
str
,
int
]]
=
None
,
requires_grad
:
bool
=
False
,
)
->
F
:
if
isinstance
(
device
,
str
):
device
=
torch
.
device
(
device
)
feature
=
cast
(
return
cast
(
F
,
torch
.
Tensor
.
_make_subclass
(
cast
(
_TensorBase
,
cls
),
cls
.
_to_tensor
(
data
,
dtype
=
dtype
,
device
=
device
),
# requires_grad
False
,
torch
.
as_tensor
(
data
,
dtype
=
dtype
,
device
=
device
),
# type: ignore[arg-type]
requires_grad
,
),
)
feature
.
_metadata
=
dict
()
return
feature
@
classmethod
def
_to_tensor
(
self
,
data
:
Any
,
*
,
dtype
:
Optional
[
torch
.
dtype
],
device
:
Optional
[
torch
.
device
])
->
torch
.
Tensor
:
return
torch
.
as_tensor
(
data
,
dtype
=
dtype
,
device
=
device
)
@
classmethod
def
new_like
(
...
...
@@ -69,12 +32,17 @@ class _Feature(torch.Tensor):
data
:
Any
,
*
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
device
:
Optional
[
Union
[
torch
.
device
,
str
]]
=
None
,
**
metadata
:
Any
,
device
:
Optional
[
Union
[
torch
.
device
,
str
,
int
]]
=
None
,
requires_grad
:
Optional
[
bool
]
=
None
,
**
kwargs
:
Any
,
)
->
F
:
_metadata
=
other
.
_metadata
.
copy
()
_metadata
.
update
(
metadata
)
return
cls
(
data
,
dtype
=
dtype
or
other
.
dtype
,
device
=
device
or
other
.
device
,
**
_metadata
)
return
cls
(
data
,
dtype
=
dtype
if
dtype
is
not
None
else
other
.
dtype
,
device
=
device
if
device
is
not
None
else
other
.
device
,
requires_grad
=
requires_grad
if
requires_grad
is
not
None
else
other
.
requires_grad
,
**
kwargs
,
)
@
classmethod
def
__torch_function__
(
...
...
torchvision/prototype/features/_image.py
View file @
e836b3d8
...
...
@@ -26,11 +26,17 @@ class Image(_Feature):
cls
,
data
:
Any
,
*
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
device
:
Optional
[
torch
.
device
]
=
None
,
color_space
:
Optional
[
Union
[
ColorSpace
,
str
]]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
device
:
Optional
[
Union
[
torch
.
device
,
str
,
int
]]
=
None
,
requires_grad
:
bool
=
False
,
)
->
Image
:
image
=
super
().
__new__
(
cls
,
data
,
dtype
=
dtype
,
device
=
device
)
data
=
torch
.
as_tensor
(
data
,
dtype
=
dtype
,
device
=
device
)
# type: ignore[arg-type]
if
data
.
ndim
<
2
:
raise
ValueError
elif
data
.
ndim
==
2
:
data
=
data
.
unsqueeze
(
0
)
image
=
super
().
__new__
(
cls
,
data
,
requires_grad
=
requires_grad
)
if
color_space
is
None
:
color_space
=
cls
.
guess_color_space
(
image
)
...
...
@@ -38,19 +44,19 @@ class Image(_Feature):
warnings
.
warn
(
"Unable to guess a specific color space. Consider passing it explicitly."
)
elif
isinstance
(
color_space
,
str
):
color_space
=
ColorSpace
.
from_str
(
color_space
.
upper
())
image
.
_metadata
.
update
(
dict
(
color_space
=
color_space
))
elif
not
isinstance
(
color_space
,
ColorSpace
):
raise
ValueError
image
.
color_space
=
color_space
return
image
@
classmethod
def
_to_tensor
(
cls
,
data
:
Any
,
*
,
dtype
:
Optional
[
torch
.
dtype
],
device
:
Optional
[
torch
.
device
])
->
torch
.
Tensor
:
tensor
=
super
().
_to_tensor
(
data
,
dtype
=
dtype
,
device
=
device
)
if
tensor
.
ndim
<
2
:
raise
ValueError
elif
tensor
.
ndim
==
2
:
tensor
=
tensor
.
unsqueeze
(
0
)
return
tensor
def
new_like
(
cls
,
other
:
Image
,
data
:
Any
,
*
,
color_space
:
Optional
[
Union
[
ColorSpace
,
str
]]
=
None
,
**
kwargs
:
Any
)
->
Image
:
return
super
().
new_like
(
other
,
data
,
color_space
=
color_space
if
color_space
is
not
None
else
other
.
color_space
,
**
kwargs
)
@
property
def
image_size
(
self
)
->
Tuple
[
int
,
int
]:
...
...
torchvision/prototype/features/_label.py
View file @
e836b3d8
from
__future__
import
annotations
from
typing
import
Any
,
Optional
,
Sequence
,
cast
from
typing
import
Any
,
Optional
,
Sequence
,
cast
,
Union
import
torch
from
torchvision.prototype.utils._internal
import
apply_recursively
...
...
@@ -15,20 +15,32 @@ class Label(_Feature):
cls
,
data
:
Any
,
*
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
device
:
Optional
[
torch
.
device
]
=
None
,
like
:
Optional
[
Label
]
=
None
,
categories
:
Optional
[
Sequence
[
str
]]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
device
:
Optional
[
Union
[
torch
.
device
,
str
,
int
]]
=
None
,
requires_grad
:
bool
=
False
,
)
->
Label
:
label
=
super
().
__new__
(
cls
,
data
,
dtype
=
dtype
,
device
=
device
)
label
=
super
().
__new__
(
cls
,
data
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
label
.
_metadata
.
update
(
dict
(
categories
=
categories
))
label
.
categories
=
categories
return
label
@
classmethod
def
from_category
(
cls
,
category
:
str
,
*
,
categories
:
Sequence
[
str
])
->
Label
:
return
cls
(
categories
.
index
(
category
),
categories
=
categories
)
def
new_like
(
cls
,
other
:
Label
,
data
:
Any
,
*
,
categories
:
Optional
[
Sequence
[
str
]]
=
None
,
**
kwargs
:
Any
)
->
Label
:
return
super
().
new_like
(
other
,
data
,
categories
=
categories
if
categories
is
not
None
else
other
.
categories
,
**
kwargs
)
@
classmethod
def
from_category
(
cls
,
category
:
str
,
*
,
categories
:
Sequence
[
str
],
**
kwargs
:
Any
,
)
->
Label
:
return
cls
(
categories
.
index
(
category
),
categories
=
categories
,
**
kwargs
)
def
to_categories
(
self
)
->
Any
:
if
not
self
.
categories
:
...
...
@@ -44,16 +56,24 @@ class OneHotLabel(_Feature):
cls
,
data
:
Any
,
*
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
device
:
Optional
[
torch
.
device
]
=
None
,
like
:
Optional
[
Label
]
=
None
,
categories
:
Optional
[
Sequence
[
str
]]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
device
:
Optional
[
Union
[
torch
.
device
,
str
,
int
]]
=
None
,
requires_grad
:
bool
=
False
,
)
->
OneHotLabel
:
one_hot_label
=
super
().
__new__
(
cls
,
data
,
dtype
=
dtype
,
device
=
device
)
one_hot_label
=
super
().
__new__
(
cls
,
data
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
if
categories
is
not
None
and
len
(
categories
)
!=
one_hot_label
.
shape
[
-
1
]:
raise
ValueError
()
one_hot_label
.
_metadata
.
update
(
dict
(
categories
=
categories
))
one_hot_label
.
categories
=
categories
return
one_hot_label
@
classmethod
def
new_like
(
cls
,
other
:
OneHotLabel
,
data
:
Any
,
*
,
categories
:
Optional
[
Sequence
[
str
]]
=
None
,
**
kwargs
:
Any
)
->
OneHotLabel
:
return
super
().
new_like
(
other
,
data
,
categories
=
categories
if
categories
is
not
None
else
other
.
categories
,
**
kwargs
)
torchvision/prototype/transforms/_geometry.py
View file @
e836b3d8
...
...
@@ -46,7 +46,7 @@ class Resize(Transform):
return
features
.
SegmentationMask
.
new_like
(
input
,
output
)
elif
isinstance
(
input
,
features
.
BoundingBox
):
output
=
F
.
resize_bounding_box
(
input
,
self
.
size
,
image_size
=
input
.
image_size
)
return
features
.
BoundingBox
.
new_like
(
input
,
output
,
image_size
=
self
.
size
)
return
features
.
BoundingBox
.
new_like
(
input
,
output
,
image_size
=
cast
(
Tuple
[
int
,
int
],
tuple
(
self
.
size
)
))
elif
isinstance
(
input
,
PIL
.
Image
.
Image
):
return
F
.
resize_image_pil
(
input
,
self
.
size
,
interpolation
=
self
.
interpolation
)
elif
isinstance
(
input
,
torch
.
Tensor
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment