Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
a893f313
Unverified
Commit
a893f313
authored
Aug 02, 2023
by
Philip Meier
Committed by
GitHub
Aug 02, 2023
Browse files
refactor Datapoint dispatch mechanism (#7747)
Co-authored-by:
Nicolas Hug
<
contact@nicolas-hug.com
>
parent
16d62e30
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
287 additions
and
125 deletions
+287
-125
torchvision/transforms/v2/functional/_meta.py
torchvision/transforms/v2/functional/_meta.py
+73
-72
torchvision/transforms/v2/functional/_misc.py
torchvision/transforms/v2/functional/_misc.py
+63
-38
torchvision/transforms/v2/functional/_temporal.py
torchvision/transforms/v2/functional/_temporal.py
+19
-12
torchvision/transforms/v2/functional/_utils.py
torchvision/transforms/v2/functional/_utils.py
+132
-3
No files found.
torchvision/transforms/v2/functional/_meta.py
View file @
a893f313
...
...
@@ -8,9 +8,29 @@ from torchvision.transforms import _functional_pil as _FP
from
torchvision.utils
import
_log_api_usage_once
from
._utils
import
is_simple_tensor
from
._utils
import
_get_kernel
,
_register_kernel_internal
,
_register_unsupported_type
,
is_simple_tensor
@_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
def get_dimensions(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]:
    """Return the dimensions ``[channels, height, width]`` of an image or video input.

    Plain tensors go straight to the tensor kernel; datapoints dispatch through the
    kernel registry; PIL images use the PIL kernel. Anything else is rejected.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(get_dimensions)

    # Under torch.jit scripting only the plain-tensor path is reachable.
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return get_dimensions_image_tensor(inpt)
    if isinstance(inpt, datapoints.Datapoint):
        # Look up whichever kernel was registered for this datapoint type.
        type_specific_kernel = _get_kernel(get_dimensions, type(inpt))
        return type_specific_kernel(inpt)
    if isinstance(inpt, PIL.Image.Image):
        return get_dimensions_image_pil(inpt)
    raise TypeError(
        f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
        f"but got {type(inpt)} instead."
    )
@
_register_kernel_internal
(
get_dimensions
,
datapoints
.
Image
,
datapoint_wrapper
=
False
)
def
get_dimensions_image_tensor
(
image
:
torch
.
Tensor
)
->
List
[
int
]:
chw
=
list
(
image
.
shape
[
-
3
:])
ndims
=
len
(
chw
)
...
...
@@ -26,31 +46,31 @@ def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
# PIL images reuse the kernel from the private functional-pil helper module.
get_dimensions_image_pil = _FP.get_dimensions


@_register_kernel_internal(get_dimensions, datapoints.Video, datapoint_wrapper=False)
def get_dimensions_video(video: torch.Tensor) -> List[int]:
    """Return [channels, height, width] of a video tensor.

    Videos share the trailing CHW layout of image tensors, so the image kernel applies.
    """
    return get_dimensions_image_tensor(video)
def
get_dimensions
(
inpt
:
Union
[
datapoints
.
_ImageTypeJIT
,
datapoints
.
_VideoTypeJIT
])
->
List
[
int
]:
@_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> int:
    """Return the number of channels of an image or video input.

    This span in the rendered diff interleaves the removed for-loop dispatch (and a
    stale ``_log_api_usage_once(get_dimensions)`` call) with the added registry-based
    dispatch; this is the added version, consistent with ``get_dimensions`` above.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(get_num_channels)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return get_num_channels_image_tensor(inpt)
    elif isinstance(inpt, datapoints.Datapoint):
        # Datapoints dispatch through the kernel registry.
        kernel = _get_kernel(get_num_channels, type(inpt))
        return kernel(inpt)
    elif isinstance(inpt, PIL.Image.Image):
        return get_num_channels_image_pil(inpt)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
@
_register_kernel_internal
(
get_num_channels
,
datapoints
.
Image
,
datapoint_wrapper
=
False
)
def
get_num_channels_image_tensor
(
image
:
torch
.
Tensor
)
->
int
:
chw
=
image
.
shape
[
-
3
:]
ndims
=
len
(
chw
)
...
...
@@ -65,36 +85,35 @@ def get_num_channels_image_tensor(image: torch.Tensor) -> int:
# PIL images reuse the kernel from the private functional-pil helper module.
get_num_channels_image_pil = _FP.get_image_num_channels


@_register_kernel_internal(get_num_channels, datapoints.Video, datapoint_wrapper=False)
def get_num_channels_video(video: torch.Tensor) -> int:
    """Return the channel count of a video tensor via the image kernel (same CHW tail)."""
    return get_num_channels_image_tensor(video)
def
get_num_channels
(
inpt
:
Union
[
datapoints
.
_ImageTypeJIT
,
datapoints
.
_VideoTypeJIT
])
->
int
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
get_num_channels
)
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
get_num_channels_image_tensor
(
inpt
)
for
typ
,
get_size_fn
in
{
datapoints
.
Image
:
get_num_channels_image_tensor
,
datapoints
.
Video
:
get_num_channels_video
,
PIL
.
Image
.
Image
:
get_num_channels_image_pil
,
}.
items
():
if
isinstance
(
inpt
,
typ
):
return
get_size_fn
(
inpt
)
raise
TypeError
(
f
"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
f
"but got
{
type
(
inpt
)
}
instead."
)
# We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without
# deprecating the old names.
# NOTE: this is a plain alias — both names refer to the exact same dispatcher object.
get_image_num_channels = get_num_channels
def get_size(inpt: datapoints._InputTypeJIT) -> List[int]:
    """Return the spatial size ``[height, width]`` of the input.

    Plain tensors use the tensor kernel, datapoints dispatch through the kernel
    registry, PIL images use the PIL kernel; any other type raises TypeError.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(get_size)

    # Under torch.jit scripting only the plain-tensor path is reachable.
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return get_size_image_tensor(inpt)
    if isinstance(inpt, datapoints.Datapoint):
        type_specific_kernel = _get_kernel(get_size, type(inpt))
        return type_specific_kernel(inpt)
    if isinstance(inpt, PIL.Image.Image):
        return get_size_image_pil(inpt)
    raise TypeError(
        f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
        f"but got {type(inpt)} instead."
    )
@
_register_kernel_internal
(
get_size
,
datapoints
.
Image
,
datapoint_wrapper
=
False
)
def
get_size_image_tensor
(
image
:
torch
.
Tensor
)
->
List
[
int
]:
hw
=
list
(
image
.
shape
[
-
2
:])
ndims
=
len
(
hw
)
...
...
@@ -110,59 +129,41 @@ def get_size_image_pil(image: PIL.Image.Image) -> List[int]:
return
[
height
,
width
]
@_register_kernel_internal(get_size, datapoints.Video, datapoint_wrapper=False)
def get_size_video(video: torch.Tensor) -> List[int]:
    """Return [height, width] of a video tensor via the image kernel (same HW tail)."""
    return get_size_image_tensor(video)
@_register_kernel_internal(get_size, datapoints.Mask, datapoint_wrapper=False)
def get_size_mask(mask: torch.Tensor) -> List[int]:
    """Return [height, width] of a mask tensor via the image kernel (same HW tail)."""
    return get_size_image_tensor(mask)
# torch.jit.unused: bounding boxes carry metadata that is not scriptable.
@torch.jit.unused
@_register_kernel_internal(get_size, datapoints.BoundingBoxes, datapoint_wrapper=False)
def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]:
    """Return [height, width] of the canvas the bounding boxes live on (not of the box tensor itself)."""
    return list(bounding_box.canvas_size)
def
get_size
(
inpt
:
datapoints
.
_InputTypeJIT
)
->
List
[
int
]:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
get_size
)
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
get_size_image_tensor
(
inpt
)
# TODO: This is just the poor mans version of a dispatcher. This will be properly addressed with
# https://github.com/pytorch/vision/pull/7747 when we can register the kernels above without the need to have
# a method on the datapoint class
for
typ
,
get_size_fn
in
{
datapoints
.
Image
:
get_size_image_tensor
,
datapoints
.
BoundingBoxes
:
get_size_bounding_boxes
,
datapoints
.
Mask
:
get_size_mask
,
datapoints
.
Video
:
get_size_video
,
PIL
.
Image
.
Image
:
get_size_image_pil
,
}.
items
():
if
isinstance
(
inpt
,
typ
):
return
get_size_fn
(
inpt
)
raise
TypeError
(
f
"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f
"but got
{
type
(
inpt
)
}
instead."
)
def
get_num_frames_video
(
video
:
torch
.
Tensor
)
->
int
:
return
video
.
shape
[
-
4
]
@_register_unsupported_type(PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask)
def get_num_frames(inpt: datapoints._VideoTypeJIT) -> int:
    """Return the number of frames of a video input.

    The rendered diff interleaves the removed explicit ``datapoints.Video`` branch and
    its old error message with the added registry-based dispatch; this is the added
    version. Non-video types listed in the decorator raise via the registered
    unsupported-type kernel.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(get_num_frames)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return get_num_frames_video(inpt)
    elif isinstance(inpt, datapoints.Datapoint):
        kernel = _get_kernel(get_num_frames, type(inpt))
        return kernel(inpt)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
@_register_kernel_internal(get_num_frames, datapoints.Video, datapoint_wrapper=False)
def get_num_frames_video(video: torch.Tensor) -> int:
    """Return the temporal length of a video tensor.

    Assumes the layout ``(..., T, C, H, W)``, i.e. frames on the fourth-to-last axis.
    """
    return video.shape[-4]
def
_xywh_to_xyxy
(
xywh
:
torch
.
Tensor
,
inplace
:
bool
)
->
torch
.
Tensor
:
...
...
torchvision/transforms/v2/functional/_misc.py
View file @
a893f313
...
...
@@ -11,9 +11,37 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from
torchvision.utils
import
_log_api_usage_once
from
._utils
import
is_simple_tensor
from
._utils
import
(
_get_kernel
,
_register_explicit_noop
,
_register_kernel_internal
,
_register_unsupported_type
,
is_simple_tensor
,
)
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
@_register_unsupported_type(PIL.Image.Image)
def normalize(
    inpt: Union[datapoints._TensorImageTypeJIT, datapoints._TensorVideoTypeJIT],
    mean: List[float],
    std: List[float],
    inplace: bool = False,
) -> torch.Tensor:
    """Normalize a tensor image or video with per-channel mean and std.

    Plain tensors go to the image kernel; datapoints dispatch through the kernel
    registry. PIL images are rejected via the unsupported-type registration.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(normalize)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
    if isinstance(inpt, datapoints.Datapoint):
        type_specific_kernel = _get_kernel(normalize, type(inpt))
        return type_specific_kernel(inpt, mean=mean, std=std, inplace=inplace)
    raise TypeError(
        f"Input can either be a plain tensor or any TorchVision datapoint, but got {type(inpt)} instead."
    )
@
_register_kernel_internal
(
normalize
,
datapoints
.
Image
)
def
normalize_image_tensor
(
image
:
torch
.
Tensor
,
mean
:
List
[
float
],
std
:
List
[
float
],
inplace
:
bool
=
False
)
->
torch
.
Tensor
:
...
...
@@ -49,25 +77,29 @@ def normalize_image_tensor(
return
image
.
div_
(
std
)
@_register_kernel_internal(normalize, datapoints.Video)
def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
    """Normalize a video tensor per channel; frames share the image kernel's CHW math."""
    return normalize_image_tensor(video, mean, std, inplace=inplace)
def
normalize
(
inpt
:
Union
[
datapoints
.
_TensorImageTypeJIT
,
datapoints
.
_TensorVideoTypeJIT
],
mean
:
List
[
float
],
std
:
List
[
float
],
inplace
:
bool
=
False
,
)
->
torch
.
Tensor
:
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def gaussian_blur(
    inpt: datapoints._InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> datapoints._InputTypeJIT:
    """Blur the input with a Gaussian kernel.

    The rendered diff interleaves the removed method-based dispatch (a stale
    ``_log_api_usage_once(normalize)`` call and ``inpt.gaussian_blur(...)`` branch)
    with the added registry-based dispatch; this is the added version.
    BoundingBoxes/Mask pass through unchanged via the explicit no-op registration.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(gaussian_blur)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
    elif isinstance(inpt, datapoints.Datapoint):
        kernel = _get_kernel(gaussian_blur, type(inpt))
        return kernel(inpt, kernel_size=kernel_size, sigma=sigma)
    elif isinstance(inpt, PIL.Image.Image):
        return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
...
...
@@ -87,6 +119,7 @@ def _get_gaussian_kernel2d(
return
kernel2d
@
_register_kernel_internal
(
gaussian_blur
,
datapoints
.
Image
)
def
gaussian_blur_image_tensor
(
image
:
torch
.
Tensor
,
kernel_size
:
List
[
int
],
sigma
:
Optional
[
List
[
float
]]
=
None
)
->
torch
.
Tensor
:
...
...
@@ -160,28 +193,27 @@ def gaussian_blur_image_pil(
return
to_pil_image
(
output
,
mode
=
image
.
mode
)
@
_register_kernel_internal
(
gaussian_blur
,
datapoints
.
Video
)
def
gaussian_blur_video
(
video
:
torch
.
Tensor
,
kernel_size
:
List
[
int
],
sigma
:
Optional
[
List
[
float
]]
=
None
)
->
torch
.
Tensor
:
return
gaussian_blur_image_tensor
(
video
,
kernel_size
,
sigma
)
def
gaussian_blur
(
inpt
:
datapoints
.
_InputTypeJIT
,
kernel_size
:
List
[
int
],
sigma
:
Optional
[
List
[
float
]]
=
Non
e
def to_dtype(
    inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False
) -> datapoints._InputTypeJIT:
    """Convert the input to the given dtype, optionally rescaling values.

    The rendered diff interleaves the removed per-type branches (including a stale
    ``_log_api_usage_once(gaussian_blur)`` call) with the added registry-based
    dispatch; this is the added version. ``scale=True`` rescales value ranges
    (e.g. uint8 [0, 255] <-> float [0, 1]) in the image kernel.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(to_dtype)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return to_dtype_image_tensor(inpt, dtype, scale=scale)
    elif isinstance(inpt, datapoints.Datapoint):
        kernel = _get_kernel(to_dtype, type(inpt))
        return kernel(inpt, dtype, scale=scale)
    else:
        raise TypeError(
            f"Input can either be a plain tensor or any TorchVision datapoint, but got {type(inpt)} instead."
        )
...
...
@@ -200,6 +232,7 @@ def _num_value_bits(dtype: torch.dtype) -> int:
raise
TypeError
(
f
"Number of value bits is only defined for integer dtypes, but got
{
dtype
}
."
)
@
_register_kernel_internal
(
to_dtype
,
datapoints
.
Image
)
def
to_dtype_image_tensor
(
image
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
=
torch
.
float
,
scale
:
bool
=
False
)
->
torch
.
Tensor
:
if
image
.
dtype
==
dtype
:
...
...
@@ -257,23 +290,15 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float32)
return
to_dtype_image_tensor
(
image
,
dtype
=
dtype
,
scale
=
True
)
@_register_kernel_internal(to_dtype, datapoints.Video)
def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
    """Convert a video tensor's dtype via the image kernel (per-element, layout-agnostic)."""
    return to_dtype_image_tensor(video, dtype, scale=scale)
def
to_dtype
(
inpt
:
datapoints
.
_InputTypeJIT
,
dtype
:
torch
.
dtype
=
torch
.
float
,
scale
:
bool
=
False
)
->
torch
.
Tensor
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
to_dtype
)
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
to_dtype_image_tensor
(
inpt
,
dtype
,
scale
=
scale
)
elif
isinstance
(
inpt
,
datapoints
.
Image
):
output
=
to_dtype_image_tensor
(
inpt
.
as_subclass
(
torch
.
Tensor
),
dtype
,
scale
=
scale
)
return
datapoints
.
Image
.
wrap_like
(
inpt
,
output
)
elif
isinstance
(
inpt
,
datapoints
.
Video
):
output
=
to_dtype_video
(
inpt
.
as_subclass
(
torch
.
Tensor
),
dtype
,
scale
=
scale
)
return
datapoints
.
Video
.
wrap_like
(
inpt
,
output
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
to
(
dtype
)
else
:
raise
TypeError
(
f
"Input can either be a plain tensor or a datapoint, but got
{
type
(
inpt
)
}
instead."
)
# One shared kernel for both metadata-free tensor datapoints; note that `scale` is
# accepted for signature compatibility but deliberately ignored here.
@_register_kernel_internal(to_dtype, datapoints.BoundingBoxes, datapoint_wrapper=False)
@_register_kernel_internal(to_dtype, datapoints.Mask, datapoint_wrapper=False)
def _to_dtype_tensor_dispatch(
    inpt: datapoints._InputTypeJIT, dtype: torch.dtype, scale: bool = False
) -> datapoints._InputTypeJIT:
    # We don't need to unwrap and rewrap here, since Datapoint.to() preserves the type
    return inpt.to(dtype)
torchvision/transforms/v2/functional/_temporal.py
View file @
a893f313
import
PIL.Image
import
torch
from
torchvision
import
datapoints
from
torchvision.utils
import
_log_api_usage_once
from
._utils
import
is_simple_tensor
def
uniform_temporal_subsample_video
(
video
:
torch
.
Tensor
,
num_samples
:
int
)
->
torch
.
Tensor
:
# Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
t_max
=
video
.
shape
[
-
4
]
-
1
indices
=
torch
.
linspace
(
0
,
t_max
,
num_samples
,
device
=
video
.
device
).
long
()
return
torch
.
index_select
(
video
,
-
4
,
indices
)
from
._utils
import
_get_kernel
,
_register_explicit_noop
,
_register_kernel_internal
,
is_simple_tensor
@_register_explicit_noop(
    PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True
)
def uniform_temporal_subsample(inpt: datapoints._VideoTypeJIT, num_samples: int) -> datapoints._VideoTypeJIT:
    """Uniformly subsample ``num_samples`` frames from a video input.

    The rendered diff interleaves the removed explicit ``datapoints.Video`` branch
    and old error message with the added registry-based dispatch; this is the added
    version. Non-video inputs listed in the decorator pass through with a warning.
    """
    if not torch.jit.is_scripting():
        _log_api_usage_once(uniform_temporal_subsample)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
        return uniform_temporal_subsample_video(inpt, num_samples)
    elif isinstance(inpt, datapoints.Datapoint):
        kernel = _get_kernel(uniform_temporal_subsample, type(inpt))
        return kernel(inpt, num_samples)
    else:
        raise TypeError(
            f"Input can either be a plain tensor or any TorchVision datapoint, but got {type(inpt)} instead."
        )
@_register_kernel_internal(uniform_temporal_subsample, datapoints.Video)
def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int) -> torch.Tensor:
    """Select ``num_samples`` frames evenly spaced over the video's temporal axis.

    Assumes layout ``(..., T, C, H, W)``; endpoints (first and last frame) are always
    included because ``linspace`` spans the full [0, T-1] range.
    """
    # Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
    t_max = video.shape[-4] - 1
    # linspace gives evenly spaced float positions; truncate to integer frame indices.
    indices = torch.linspace(0, t_max, num_samples, device=video.device).long()
    return torch.index_select(video, -4, indices)
torchvision/transforms/v2/functional/_utils.py
View file @
a893f313
from
typing
import
Any
import
functools
import
warnings
from
typing
import
Any
,
Callable
,
Dict
,
Type
import
torch
from
torchvision
.datapoints._datapoint
import
D
atapoint
from
torchvision
import
d
atapoint
s
def is_simple_tensor(inpt: Any) -> bool:
    """Return True iff *inpt* is a plain tensor, i.e. a Tensor that is not a Datapoint.

    The rendered diff shows both the removed return (referencing the no-longer-imported
    bare ``Datapoint`` name) and the added one; only the added line is kept here.
    """
    return isinstance(inpt, torch.Tensor) and not isinstance(inpt, datapoints.Datapoint)
# {dispatcher: {input_type: type_specific_kernel}}
# Populated at import time by the _register_* decorators below; read by _get_kernel.
_KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}
def
_kernel_datapoint_wrapper
(
kernel
):
@
functools
.
wraps
(
kernel
)
def
wrapper
(
inpt
,
*
args
,
**
kwargs
):
output
=
kernel
(
inpt
.
as_subclass
(
torch
.
Tensor
),
*
args
,
**
kwargs
)
return
type
(
inpt
).
wrap_like
(
inpt
,
output
)
return
wrapper
def _register_kernel_internal(dispatcher, datapoint_cls, *, datapoint_wrapper=True):
    """Decorator factory: register a kernel for *datapoint_cls* on *dispatcher*.

    By default the kernel is wrapped with ``_kernel_datapoint_wrapper`` (unwrap the
    datapoint, run, re-wrap); pass ``datapoint_wrapper=False`` to register it raw.
    Raises TypeError if a kernel is already registered for that type.
    """
    registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
    if datapoint_cls in registry:
        raise TypeError(
            f"Dispatcher '{dispatcher.__name__}' already has a kernel registered for type '{datapoint_cls.__name__}'."
        )

    def decorator(kernel):
        if datapoint_wrapper:
            entry = _kernel_datapoint_wrapper(kernel)
        else:
            entry = kernel
        registry[datapoint_cls] = entry
        # Return the undecorated kernel so it stays directly callable/importable.
        return kernel

    return decorator
def register_kernel(dispatcher, datapoint_cls):
    """Public registration hook: register a kernel without the unwrap/re-wrap wrapper."""
    return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False)
def _get_kernel(dispatcher, datapoint_cls):
    """Resolve the kernel registered on *dispatcher* for *datapoint_cls*.

    Resolution order: exact type match, then the first registered superclass in
    registration order, then the pass-through ``_noop`` fallback. Raises ValueError
    if the dispatcher has no registry at all.
    """
    registry = _KERNEL_REGISTRY.get(dispatcher)
    if not registry:
        raise ValueError(f"No kernel registered for dispatcher '{dispatcher.__name__}'.")

    exact_match = registry.get(datapoint_cls)
    if exact_match is not None:
        return exact_match

    # Fall back to any registered base class (first match in insertion order wins).
    superclass_matches = (fn for base, fn in registry.items() if issubclass(datapoint_cls, base))
    return next(superclass_matches, _noop)
# Everything below this block is stuff that we need right now, since it looks like we need to release in an intermediate
# stage. See https://github.com/pytorch/vision/pull/7747#issuecomment-1661698450 for details.
# In the future, the default behavior will be to error on unsupported types in dispatchers. The noop behavior that we
# need for transforms will be handled by _get_kernel rather than actually registering no-ops on the dispatcher.
# Finally, the use case of preventing users from registering kernels for our builtin types will be handled inside
# register_kernel.
def _register_explicit_noop(*datapoints_classes, warn_passthrough=False):
    """
    Although this looks redundant with the no-op behavior of _get_kernel, this explicit registration prevents users
    from registering kernels for builtin datapoints on builtin dispatchers that rely on the no-op behavior.

    For example, without explicit no-op registration the following would be valid user code:

    .. code::

        from torchvision.transforms.v2 import functional as F

        @F.register_kernel(F.adjust_brightness, datapoints.BoundingBox)
        def lol(...):
            ...
    """

    def decorator(dispatcher):
        for cls in datapoints_classes:
            if warn_passthrough:
                warning = (
                    f"F.{dispatcher.__name__} is currently passing through inputs of type datapoints.{cls.__name__}. "
                    f"This will likely change in the future."
                )
            else:
                warning = None
            register_kernel(dispatcher, cls)(functools.partial(_noop, __msg__=warning))
        return dispatcher

    return decorator
def
_noop
(
inpt
,
*
args
,
__msg__
=
None
,
**
kwargs
):
if
__msg__
:
warnings
.
warn
(
__msg__
,
UserWarning
,
stacklevel
=
2
)
return
inpt
# TODO: we only need this, since our default behavior in case no kernel is found is passthrough. When we change that
# to error later, this decorator can be removed, since the error will be raised by _get_kernel
def _register_unsupported_type(*datapoints_classes):
    """Decorator factory: register an always-raising kernel for each given class.

    Used so dispatchers reject specific datapoint types with a clear TypeError
    instead of silently passing them through.
    """

    def unsupported(inpt, *args, __dispatcher_name__, **kwargs):
        raise TypeError(f"F.{__dispatcher_name__} does not support inputs of type {type(inpt)}.")

    def decorator(dispatcher):
        for cls in datapoints_classes:
            register_kernel(dispatcher, cls)(
                functools.partial(unsupported, __dispatcher_name__=dispatcher.__name__)
            )
        return dispatcher

    return decorator
# This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop
# We could get rid of this by letting _register_kernel_internal take arbitrary dispatchers rather than wrap_kernel: bool
# TODO: decide if we want that
def _register_five_ten_crop_kernel(dispatcher, datapoint_cls):
    """Register a kernel whose output is a container of tensors (five_crop / ten_crop).

    Unlike the wrapper used by ``_register_kernel_internal``, every element of the
    returned container is individually re-wrapped into the input's datapoint type.
    Raises TypeError if a kernel is already registered for *datapoint_cls*.
    """
    registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
    if datapoint_cls in registry:
        raise TypeError(
            f"Dispatcher '{dispatcher.__name__}' already has a kernel registered for type '{datapoint_cls.__name__}'."
        )

    def decorator(kernel):
        @functools.wraps(kernel)
        def wrapper(inpt, *args, **kwargs):
            result = kernel(inpt, *args, **kwargs)
            # Rebuild the same container type with each piece wrapped like the input.
            rewrap = type(inpt).wrap_like
            return type(result)(rewrap(inpt, piece) for piece in result)

        registry[datapoint_cls] = wrapper
        # Return the undecorated kernel so it stays directly callable/importable.
        return kernel

    return decorator
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment