Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
d5f4cc38
Unverified
Commit
d5f4cc38
authored
Aug 30, 2023
by
Nicolas Hug
Committed by
GitHub
Aug 30, 2023
Browse files
Datapoint -> TVTensor; datapoint[s] -> tv_tensor[s] (#7894)
parent
b9447fdd
Changes
85
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
23 additions
and
23 deletions
+23
-23
torchvision/tv_tensors/_image.py
torchvision/tv_tensors/_image.py
+2
-2
torchvision/tv_tensors/_mask.py
torchvision/tv_tensors/_mask.py
+2
-2
torchvision/tv_tensors/_torch_function_helpers.py
torchvision/tv_tensors/_torch_function_helpers.py
+7
-7
torchvision/tv_tensors/_tv_tensor.py
torchvision/tv_tensors/_tv_tensor.py
+10
-10
torchvision/tv_tensors/_video.py
torchvision/tv_tensors/_video.py
+2
-2
No files found.
torchvision/
datapoint
s/_image.py
→
torchvision/
tv_tensor
s/_image.py
View file @
d5f4cc38
...
...
@@ -5,10 +5,10 @@ from typing import Any, Optional, Union
import
PIL.Image
import
torch
from
._
datapoint
import
Datapoint
from
._
tv_tensor
import
TVTensor
class
Image
(
Datapoint
):
class
Image
(
TVTensor
):
"""[BETA] :class:`torch.Tensor` subclass for images.
.. note::
...
...
torchvision/
datapoint
s/_mask.py
→
torchvision/
tv_tensor
s/_mask.py
View file @
d5f4cc38
...
...
@@ -5,10 +5,10 @@ from typing import Any, Optional, Union
import
PIL.Image
import
torch
from
._
datapoint
import
Datapoint
from
._
tv_tensor
import
TVTensor
class
Mask
(
Datapoint
):
class
Mask
(
TVTensor
):
"""[BETA] :class:`torch.Tensor` subclass for segmentation and detection masks.
Args:
...
...
torchvision/
datapoint
s/_torch_function_helpers.py
→
torchvision/
tv_tensor
s/_torch_function_helpers.py
View file @
d5f4cc38
...
...
@@ -16,7 +16,7 @@ class _ReturnTypeCM:
def
set_return_type
(
return_type
:
str
):
"""[BETA] Set the return type of torch operations on
datapoint
s.
"""[BETA] Set the return type of torch operations on
tv_tensor
s.
This only affects the behaviour of torch operations. It has no effect on
``torchvision`` transforms or functionals, which will always return as
...
...
@@ -33,28 +33,28 @@ def set_return_type(return_type: str):
.. code:: python
img =
datapoint
s.Image(torch.rand(3, 5, 5))
img =
tv_tensor
s.Image(torch.rand(3, 5, 5))
img + 2 # This is a pure Tensor (default behaviour)
set_return_type("
datapoint
s")
set_return_type("
tv_tensor
s")
img + 2 # This is an Image
or as a context manager to restrict the scope:
.. code:: python
img =
datapoint
s.Image(torch.rand(3, 5, 5))
img =
tv_tensor
s.Image(torch.rand(3, 5, 5))
img + 2 # This is a pure Tensor
with set_return_type("
datapoint
s"):
with set_return_type("
tv_tensor
s"):
img + 2 # This is an Image
img + 2 # This is a pure Tensor
Args:
return_type (str): Can be "
datapoint
" or "tensor". Default is "tensor".
return_type (str): Can be "
tv_tensor
" or "tensor". Default is "tensor".
"""
global
_TORCHFUNCTION_SUBCLASS
to_restore
=
_TORCHFUNCTION_SUBCLASS
_TORCHFUNCTION_SUBCLASS
=
{
"tensor"
:
False
,
"
datapoint
"
:
True
}[
return_type
.
lower
()]
_TORCHFUNCTION_SUBCLASS
=
{
"tensor"
:
False
,
"
tv_tensor
"
:
True
}[
return_type
.
lower
()]
return
_ReturnTypeCM
(
to_restore
)
...
...
torchvision/
datapoints/_datapoint
.py
→
torchvision/
tv_tensors/_tv_tensor
.py
View file @
d5f4cc38
...
...
@@ -6,18 +6,18 @@ import torch
from
torch._C
import
DisableTorchFunctionSubclass
from
torch.types
import
_device
,
_dtype
,
_size
from
torchvision.
datapoint
s._torch_function_helpers
import
_FORCE_TORCHFUNCTION_SUBCLASS
,
_must_return_subclass
from
torchvision.
tv_tensor
s._torch_function_helpers
import
_FORCE_TORCHFUNCTION_SUBCLASS
,
_must_return_subclass
D
=
TypeVar
(
"D"
,
bound
=
"
Datapoint
"
)
D
=
TypeVar
(
"D"
,
bound
=
"
TVTensor
"
)
class
Datapoint
(
torch
.
Tensor
):
"""[Beta] Base class for all
datapoint
s.
class
TVTensor
(
torch
.
Tensor
):
"""[Beta] Base class for all
tv_tensor
s.
You probably don't want to use this class unless you're defining your own
custom
Datapoint
s. See
:ref:`sphx_glr_auto_examples_transforms_plot_custom_
datapoint
s.py` for details.
custom
TVTensor
s. See
:ref:`sphx_glr_auto_examples_transforms_plot_custom_
tv_tensor
s.py` for details.
"""
@
staticmethod
...
...
@@ -62,9 +62,9 @@ class Datapoint(torch.Tensor):
``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
``args`` and ``kwargs`` of the original call.
Why do we override this? Because the base implementation in torch.Tensor would preserve the
Datapoint
type
Why do we override this? Because the base implementation in torch.Tensor would preserve the
TVTensor
type
of the output. In our case, we want to return pure tensors instead (with a few exceptions). Refer to the
"
Datapoint
s FAQ" gallery example for a rationale of this behaviour (TL;DR: perf + no silver bullet).
"
TVTensor
s FAQ" gallery example for a rationale of this behaviour (TL;DR: perf + no silver bullet).
Our implementation below is very similar to the base implementation in ``torch.Tensor`` - go check it out.
"""
...
...
@@ -79,7 +79,7 @@ class Datapoint(torch.Tensor):
must_return_subclass
=
_must_return_subclass
()
if
must_return_subclass
or
(
func
in
_FORCE_TORCHFUNCTION_SUBCLASS
and
isinstance
(
args
[
0
],
cls
)):
# If you're wondering why we need the `isinstance(args[0], cls)` check, remove it and see what fails
# in test_to_
datapoint
_reference().
# in test_to_
tv_tensor
_reference().
# The __torch_function__ protocol will invoke the __torch_function__ method on *all* types involved in
# the computation by walking the MRO upwards. For example,
# `out = a_pure_tensor.to(an_image)` will invoke `Image.__torch_function__` with
...
...
@@ -89,7 +89,7 @@ class Datapoint(torch.Tensor):
if
not
must_return_subclass
and
isinstance
(
output
,
cls
):
# DisableTorchFunctionSubclass is ignored by inplace ops like `.add_(...)`,
# so for those, the output is still a
Datapoint
. Thus, we need to manually unwrap.
# so for those, the output is still a
TVTensor
. Thus, we need to manually unwrap.
return
output
.
as_subclass
(
torch
.
Tensor
)
return
output
...
...
torchvision/
datapoint
s/_video.py
→
torchvision/
tv_tensor
s/_video.py
View file @
d5f4cc38
...
...
@@ -4,10 +4,10 @@ from typing import Any, Optional, Union
import
torch
from
._
datapoint
import
Datapoint
from
._
tv_tensor
import
TVTensor
class
Video
(
Datapoint
):
class
Video
(
TVTensor
):
"""[BETA] :class:`torch.Tensor` subclass for videos.
Args:
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment