Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
bf6a8dc2
Unverified
Commit
bf6a8dc2
authored
Aug 10, 2023
by
Nicolas Hug
Committed by
GitHub
Aug 10, 2023
Browse files
Simplify _NO_WRAPPING_EXCEPTIONS (#7806)
parent
f2b6f43a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
18 deletions
+11
-18
test/test_datapoints.py
test/test_datapoints.py
+0
-1
torchvision/datapoints/_datapoint.py
torchvision/datapoints/_datapoint.py
+11
-17
No files found.
test/test_datapoints.py
View file @
bf6a8dc2
...
@@ -209,4 +209,3 @@ def test_deepcopy(datapoint, requires_grad):
...
@@ -209,4 +209,3 @@ def test_deepcopy(datapoint, requires_grad):
assert
type
(
datapoint_deepcopied
)
is
type
(
datapoint
)
assert
type
(
datapoint_deepcopied
)
is
type
(
datapoint
)
assert
datapoint_deepcopied
.
requires_grad
is
requires_grad
assert
datapoint_deepcopied
.
requires_grad
is
requires_grad
assert
datapoint_deepcopied
.
is_leaf
torchvision/datapoints/_datapoint.py
View file @
bf6a8dc2
...
@@ -33,14 +33,9 @@ class Datapoint(torch.Tensor):
...
@@ -33,14 +33,9 @@ class Datapoint(torch.Tensor):
def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
    """Re-wrap ``tensor`` as an instance of the Datapoint subclass ``cls``.

    Args:
        cls: The Datapoint subclass (a ``torch.Tensor`` subclass) to wrap into.
        other: Reference datapoint. Unused in this base implementation, but
            kept in the signature for interface compatibility — presumably
            subclass overrides copy metadata from it (TODO confirm against
            the subclasses).
        tensor: The plain tensor carrying the data.

    Returns:
        ``tensor`` viewed as an instance of ``cls`` (no copy is made).
    """
    # `as_subclass` reinterprets the tensor in place as the subclass type.
    return tensor.as_subclass(cls)
# The ops in this set are those that should *preserve* the Datapoint type,
# i.e. they are exceptions to the "no wrapping" rule.
# NOTE: `Tensor.requires_grad_` is an inplace operation and thus retains the
# type automatically; it is listed here so the dispatch logic treats it the
# same way as the other type-preserving ops.
_NO_WRAPPING_EXCEPTIONS = {
    torch.Tensor.clone,
    torch.Tensor.to,
    torch.Tensor.detach,
    torch.Tensor.requires_grad_,
}
@
classmethod
@
classmethod
def
__torch_function__
(
def
__torch_function__
(
...
@@ -76,22 +71,21 @@ class Datapoint(torch.Tensor):
...
@@ -76,22 +71,21 @@ class Datapoint(torch.Tensor):
with
DisableTorchFunctionSubclass
():
with
DisableTorchFunctionSubclass
():
output
=
func
(
*
args
,
**
kwargs
or
dict
())
output
=
func
(
*
args
,
**
kwargs
or
dict
())
wrapper
=
cls
.
_NO_WRAPPING_EXCEPTIONS
.
get
(
func
)
if
func
in
cls
.
_NO_WRAPPING_EXCEPTIONS
and
isinstance
(
args
[
0
],
cls
):
#
Apart from `func` needing to be an exception, w
e also require the primary operand, i.e. `args[0]`, to be
#
W
e also require the primary operand, i.e. `args[0]`, to be
# an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
# an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
# invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
# invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
# `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with
# `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with
# `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would
# `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would
# be wrapped into a `datapoints.Image`.
# be wrapped into a `datapoints.Image`.
if
wrapper
and
isinstance
(
args
[
0
],
cls
):
return
cls
.
wrap_like
(
args
[
0
],
output
)
return
wrapper
(
cls
,
args
[
0
],
output
)
# Inplace `func`'s, canonically identified with a trailing underscore in their name like `.add_(...)`,
if
isinstance
(
output
,
cls
):
#
will retain the input type. Thus, we need to unwrap here.
#
DisableTorchFunctionSubclass is ignored by inplace ops like `.add_(...)`,
if
isinstance
(
output
,
cls
):
# so for those, the output is still a Datapoint. Thus, we need to manually unwrap.
return
output
.
as_subclass
(
torch
.
Tensor
)
return
output
.
as_subclass
(
torch
.
Tensor
)
return
output
return
output
def
_make_repr
(
self
,
**
kwargs
:
Any
)
->
str
:
def
_make_repr
(
self
,
**
kwargs
:
Any
)
->
str
:
# This is a poor man's implementation of the proposal in https://github.com/pytorch/pytorch/issues/76532.
# This is a poor man's implementation of the proposal in https://github.com/pytorch/pytorch/issues/76532.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment