OpenDAS / vision · Commits

Commit f71c4308 (Unverified)
Authored Jan 16, 2023 by Philip Meier; committed by GitHub on Jan 16, 2023
simplify dispatcher if-elif (#7084)
Parent: 69ae61a1
Showing 12 changed files with 64 additions and 107 deletions (+64 / -107)
mypy.ini                                                      +0  -1
torchvision/prototype/transforms/_type_conversion.py         +1  -1
torchvision/prototype/transforms/functional/__init__.py      +3  -0
torchvision/prototype/transforms/functional/_augment.py      +3  -3
torchvision/prototype/transforms/functional/_color.py        +11 -28
torchvision/prototype/transforms/functional/_deprecated.py   +6  -4
torchvision/prototype/transforms/functional/_geometry.py     +15 -39
torchvision/prototype/transforms/functional/_meta.py         +9  -17
torchvision/prototype/transforms/functional/_misc.py         +4  -6
torchvision/prototype/transforms/functional/_temporal.py     +3  -1
torchvision/prototype/transforms/functional/_utils.py        +8  -0
torchvision/prototype/transforms/utils.py                     +1  -7
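
Note: every change below follows the same shape. The hand-rolled check at the top of each functional dispatcher, isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, <datapoint type>)), is collapsed into a single shared helper, is_simple_tensor, which now lives in functional/_utils.py. The snippet below is a minimal standalone sketch, not torchvision source (the Datapoint class here is a stand-in for torchvision.prototype.datapoints._datapoint.Datapoint); it shows the two forms side by side and why they select the same inputs in eager mode, while under torch.jit.script the new form short-circuits on torch.jit.is_scripting() and scripted code only ever sees plain tensors anyway.

    # Minimal standalone sketch, not torchvision source.
    import torch


    class Datapoint(torch.Tensor):
        """Stand-in for torchvision.prototype.datapoints._datapoint.Datapoint."""


    def is_simple_tensor(inpt) -> bool:
        # A "simple" tensor is a plain torch.Tensor that is not a Datapoint subclass.
        return isinstance(inpt, torch.Tensor) and not isinstance(inpt, Datapoint)


    def wants_tensor_kernel_old(inpt) -> bool:
        # The pattern removed by this commit, spelled out inline in every dispatcher.
        return isinstance(inpt, torch.Tensor) and (
            torch.jit.is_scripting() or not isinstance(inpt, Datapoint)
        )


    def wants_tensor_kernel_new(inpt) -> bool:
        # The pattern introduced by this commit: one shared helper, one short condition.
        return torch.jit.is_scripting() or is_simple_tensor(inpt)


    if __name__ == "__main__":
        plain = torch.rand(3, 4, 4)
        wrapped = torch.rand(3, 4, 4).as_subclass(Datapoint)
        # In eager mode both forms agree: plain tensors go to the *_image_tensor
        # kernels, Datapoint subclasses fall through to the elif branches.
        assert wants_tensor_kernel_old(plain) and wants_tensor_kernel_new(plain)
        assert not wants_tensor_kernel_old(wrapped) and not wants_tensor_kernel_new(wrapped)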
mypy.ini

@@ -32,7 +32,6 @@ no_implicit_optional = True
 ; warnings
 warn_unused_ignores = True
-warn_return_any = True
 
 ; miscellaneous strictness flags
 allow_redefinition = True
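
Note: the setting dropped here, warn_return_any, is what forced the # type: ignore[no-any-return] comments that this commit also removes in _type_conversion.py and _color.py further down. A small illustration of what that mypy flag reports (file and function names are made up for the example):

    # illustration.py -- run with: mypy --warn-return-any illustration.py
    from typing import Any


    def untyped() -> Any:
        return 1


    def typed() -> int:
        # With warn_return_any enabled, mypy reports here:
        #   Returning Any from function declared to return "int"  [no-any-return]
        # which is exactly the error code the removed ignore comments silenced.
        return untyped()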
torchvision/prototype/transforms/_type_conversion.py

@@ -46,7 +46,7 @@ class ToImageTensor(Transform):
     def _transform(
         self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
     ) -> datapoints.Image:
-        return F.to_image_tensor(inpt)  # type: ignore[no-any-return]
+        return F.to_image_tensor(inpt)
 
 
 class ToImagePIL(Transform):
torchvision/prototype/transforms/functional/__init__.py

 # TODO: Add _log_api_usage_once() in all mid-level kernels. If they remain not jit-scriptable we can use decorators
 from torchvision.transforms import InterpolationMode  # usort: skip
+
+from ._utils import is_simple_tensor  # usort: skip
+
 from ._meta import (
     clamp_bounding_box,
     convert_format_bounding_box,
torchvision/prototype/transforms/functional/_augment.py

@@ -7,6 +7,8 @@ from torchvision.prototype import datapoints
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 from torchvision.utils import _log_api_usage_once
 
+from ._utils import is_simple_tensor
+
 
 def erase_image_tensor(
     image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False

@@ -45,9 +47,7 @@ def erase(
     if not torch.jit.is_scripting():
         _log_api_usage_once(erase)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
     elif isinstance(inpt, datapoints.Image):
         output = erase_image_tensor(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
torchvision/prototype/transforms/functional/_color.py

@@ -8,6 +8,7 @@ from torchvision.transforms.functional_tensor import _max_value
 from torchvision.utils import _log_api_usage_once
 
 from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor
+from ._utils import is_simple_tensor
 
 
 def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:

@@ -43,9 +44,7 @@ def adjust_brightness(inpt: datapoints.InputTypeJIT, brightness_factor: float) -
     if not torch.jit.is_scripting():
         _log_api_usage_once(adjust_brightness)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
     elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.adjust_brightness(brightness_factor=brightness_factor)

@@ -131,9 +130,7 @@ def adjust_contrast(inpt: datapoints.InputTypeJIT, contrast_factor: float) -> da
     if not torch.jit.is_scripting():
         _log_api_usage_once(adjust_contrast)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
     elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.adjust_contrast(contrast_factor=contrast_factor)

@@ -326,9 +323,7 @@ def adjust_hue(inpt: datapoints.InputTypeJIT, hue_factor: float) -> datapoints.I
     if not torch.jit.is_scripting():
         _log_api_usage_once(adjust_hue)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
     elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.adjust_hue(hue_factor=hue_factor)

@@ -371,9 +366,7 @@ def adjust_gamma(inpt: datapoints.InputTypeJIT, gamma: float, gain: float = 1) -
     if not torch.jit.is_scripting():
         _log_api_usage_once(adjust_gamma)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
     elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.adjust_gamma(gamma=gamma, gain=gain)

@@ -410,9 +403,7 @@ def posterize(inpt: datapoints.InputTypeJIT, bits: int) -> datapoints.InputTypeJ
     if not torch.jit.is_scripting():
         _log_api_usage_once(posterize)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return posterize_image_tensor(inpt, bits=bits)
     elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.posterize(bits=bits)

@@ -443,9 +434,7 @@ def solarize(inpt: datapoints.InputTypeJIT, threshold: float) -> datapoints.Inpu
     if not torch.jit.is_scripting():
         _log_api_usage_once(solarize)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return solarize_image_tensor(inpt, threshold=threshold)
     elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.solarize(threshold=threshold)

@@ -498,9 +487,7 @@ def autocontrast(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
     if not torch.jit.is_scripting():
         _log_api_usage_once(autocontrast)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return autocontrast_image_tensor(inpt)
     elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.autocontrast()

@@ -593,9 +580,7 @@ def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
     if not torch.jit.is_scripting():
         _log_api_usage_once(equalize)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return equalize_image_tensor(inpt)
     elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.equalize()

@@ -610,7 +595,7 @@ def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
 def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
     if image.is_floating_point():
-        return 1.0 - image  # type: ignore[no-any-return]
+        return 1.0 - image
     elif image.dtype == torch.uint8:
         return image.bitwise_not()
     else:  # signed integer dtypes

@@ -629,9 +614,7 @@ def invert(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
     if not torch.jit.is_scripting():
         _log_api_usage_once(invert)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return invert_image_tensor(inpt)
     elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.invert()
torchvision/prototype/transforms/functional/_deprecated.py

@@ -7,6 +7,8 @@ import torch
 from torchvision.prototype import datapoints
 from torchvision.transforms import functional as _F
 
+from ._utils import is_simple_tensor
+
 
 @torch.jit.unused
 def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:

@@ -25,14 +27,14 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
 def rgb_to_grayscale(
     inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1
 ) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
-    if not torch.jit.is_scripting() and isinstance(inpt, (datapoints.Image, datapoints.Video)):
-        inpt = inpt.as_subclass(torch.Tensor)
-        old_color_space = None
-    elif isinstance(inpt, torch.Tensor):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         old_color_space = datapoints._image._from_tensor_shape(inpt.shape)  # type: ignore[arg-type]
     else:
         old_color_space = None
 
+        if isinstance(inpt, (datapoints.Image, datapoints.Video)):
+            inpt = inpt.as_subclass(torch.Tensor)
+
     call = ", num_output_channels=3" if num_output_channels == 3 else ""
     replacement = (
         f"convert_color_space(..., color_space=datapoints.ColorSpace.GRAY"
torchvision/prototype/transforms/functional/_geometry.py

@@ -23,6 +23,8 @@ from torchvision.utils import _log_api_usage_once
 from ._meta import convert_format_bounding_box, get_spatial_size_image_pil
+from ._utils import is_simple_tensor
+
 
 def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
     return image.flip(-1)

@@ -60,9 +62,7 @@ def horizontal_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
     if not torch.jit.is_scripting():
         _log_api_usage_once(horizontal_flip)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return horizontal_flip_image_tensor(inpt)
     elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.horizontal_flip()

@@ -111,9 +111,7 @@ def vertical_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
     if not torch.jit.is_scripting():
         _log_api_usage_once(vertical_flip)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return vertical_flip_image_tensor(inpt)
     elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.vertical_flip()

@@ -241,9 +239,7 @@ def resize(
 ) -> datapoints.InputTypeJIT:
     if not torch.jit.is_scripting():
         _log_api_usage_once(resize)
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
     elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias)

@@ -744,9 +740,7 @@ def affine(
         _log_api_usage_once(affine)
 
     # TODO: consider deprecating integers from angle and shear on the future
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return affine_image_tensor(
             inpt,
             angle,

@@ -929,9 +923,7 @@ def rotate(
     if not torch.jit.is_scripting():
         _log_api_usage_once(rotate)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
     elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.rotate(angle, interpolation=interpolation, expand=expand, fill=fill, center=center)

@@ -1139,9 +1131,7 @@ def pad(
     if not torch.jit.is_scripting():
         _log_api_usage_once(pad)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
     elif isinstance(inpt, datapoints._datapoint.Datapoint):

@@ -1219,9 +1209,7 @@ def crop(inpt: datapoints.InputTypeJIT, top: int, left: int, height: int, width:
     if not torch.jit.is_scripting():
         _log_api_usage_once(crop)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return crop_image_tensor(inpt, top, left, height, width)
     elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.crop(top, left, height, width)

@@ -1476,9 +1464,7 @@ def perspective(
 ) -> datapoints.InputTypeJIT:
     if not torch.jit.is_scripting():
         _log_api_usage_once(perspective)
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return perspective_image_tensor(
             inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
         )

@@ -1639,9 +1625,7 @@ def elastic(
     if not torch.jit.is_scripting():
         _log_api_usage_once(elastic)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
     elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.elastic(displacement, interpolation=interpolation, fill=fill)

@@ -1754,9 +1738,7 @@ def center_crop(inpt: datapoints.InputTypeJIT, output_size: List[int]) -> datapo
     if not torch.jit.is_scripting():
         _log_api_usage_once(center_crop)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return center_crop_image_tensor(inpt, output_size)
     elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.center_crop(output_size)

@@ -1850,9 +1832,7 @@ def resized_crop(
     if not torch.jit.is_scripting():
         _log_api_usage_once(resized_crop)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return resized_crop_image_tensor(
             inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
         )

@@ -1935,9 +1915,7 @@ def five_crop(
     # TODO: consider breaking BC here to return List[datapoints.ImageTypeJIT/VideoTypeJIT] to align this op with
     # `ten_crop`
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return five_crop_image_tensor(inpt, size)
     elif isinstance(inpt, datapoints.Image):
         output = five_crop_image_tensor(inpt.as_subclass(torch.Tensor), size)

@@ -1991,9 +1969,7 @@ def ten_crop(
     if not torch.jit.is_scripting():
         _log_api_usage_once(ten_crop)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
     elif isinstance(inpt, datapoints.Image):
         output = ten_crop_image_tensor(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
torchvision/prototype/transforms/functional/_meta.py

@@ -9,6 +9,8 @@ from torchvision.transforms.functional_tensor import _max_value
 from torchvision.utils import _log_api_usage_once
 
+from ._utils import is_simple_tensor
+
 
 def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
     chw = list(image.shape[-3:])

@@ -29,9 +31,7 @@ def get_dimensions(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]
     if not torch.jit.is_scripting():
         _log_api_usage_once(get_dimensions)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return get_dimensions_image_tensor(inpt)
     elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
         channels = inpt.num_channels

@@ -68,9 +68,7 @@ def get_num_channels(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJI
     if not torch.jit.is_scripting():
         _log_api_usage_once(get_num_channels)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return get_num_channels_image_tensor(inpt)
     elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
         return inpt.num_channels

@@ -120,14 +118,12 @@ def get_spatial_size(inpt: datapoints.InputTypeJIT) -> List[int]:
     if not torch.jit.is_scripting():
         _log_api_usage_once(get_spatial_size)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return get_spatial_size_image_tensor(inpt)
     elif isinstance(inpt, (datapoints.Image, datapoints.Video, datapoints.BoundingBox, datapoints.Mask)):
         return list(inpt.spatial_size)
     elif isinstance(inpt, PIL.Image.Image):
-        return get_spatial_size_image_pil(inpt)  # type: ignore[no-any-return]
+        return get_spatial_size_image_pil(inpt)
     else:
         raise TypeError(
             f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "

@@ -143,7 +139,7 @@ def get_num_frames(inpt: datapoints.VideoTypeJIT) -> int:
     if not torch.jit.is_scripting():
         _log_api_usage_once(get_num_frames)
 
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Video)):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return get_num_frames_video(inpt)
     elif isinstance(inpt, datapoints.Video):
         return inpt.num_frames

@@ -336,9 +332,7 @@ def convert_color_space(
     if not torch.jit.is_scripting():
         _log_api_usage_once(convert_color_space)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         if old_color_space is None:
             raise RuntimeError(
                 "In order to convert the color space of simple tensors, "

@@ -443,9 +437,7 @@ def convert_dtype(
     if not torch.jit.is_scripting():
         _log_api_usage_once(convert_dtype)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return convert_dtype_image_tensor(inpt, dtype)
     elif isinstance(inpt, datapoints.Image):
         output = convert_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype)
torchvision/prototype/transforms/functional/_misc.py

@@ -10,7 +10,7 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 from torchvision.utils import _log_api_usage_once
 
-from ..utils import is_simple_tensor
+from ._utils import is_simple_tensor
 
 
 def normalize_image_tensor(

@@ -61,9 +61,9 @@ def normalize(
     if not torch.jit.is_scripting():
         _log_api_usage_once(normalize)
 
-    if is_simple_tensor(inpt) or isinstance(inpt, (datapoints.Image, datapoints.Video)):
+    if isinstance(inpt, (datapoints.Image, datapoints.Video)):
         inpt = inpt.as_subclass(torch.Tensor)
-    else:
+    elif not is_simple_tensor(inpt):
         raise TypeError(
             f"Input can either be a plain tensor or an `Image` or `Video` datapoint, "
             f"but got {type(inpt)} instead."

@@ -175,9 +175,7 @@ def gaussian_blur(
     if not torch.jit.is_scripting():
         _log_api_usage_once(gaussian_blur)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
     elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.gaussian_blur(kernel_size=kernel_size, sigma=sigma)
torchvision/prototype/transforms/functional/_temporal.py

@@ -4,6 +4,8 @@ from torchvision.prototype import datapoints
 from torchvision.utils import _log_api_usage_once
 
+from ._utils import is_simple_tensor
+
 
 def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temporal_dim: int = -4) -> torch.Tensor:
     # Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19

@@ -18,7 +20,7 @@ def uniform_temporal_subsample(
     if not torch.jit.is_scripting():
         _log_api_usage_once(uniform_temporal_subsample)
 
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Video)):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return uniform_temporal_subsample_video(inpt, num_samples, temporal_dim=temporal_dim)
     elif isinstance(inpt, datapoints.Video):
         if temporal_dim != -4 and inpt.ndim - 4 != temporal_dim:
torchvision/prototype/transforms/functional/_utils.py (new file, 0 → 100644)

+from typing import Any
+
+import torch
+
+from torchvision.prototype.datapoints._datapoint import Datapoint
+
+
+def is_simple_tensor(inpt: Any) -> bool:
+    return isinstance(inpt, torch.Tensor) and not isinstance(inpt, Datapoint)
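
Note: a quick standalone check of what the new helper distinguishes. FakeDatapoint below is a stand-in for the prototype Datapoint base class that the real helper imports; the point is that a Datapoint-style wrapper still passes isinstance(..., torch.Tensor), so the plain tensor check alone cannot tell the two apart.

    # Standalone sketch, not torchvision source.
    import torch


    class FakeDatapoint(torch.Tensor):
        """Stand-in for torchvision.prototype.datapoints._datapoint.Datapoint."""


    def is_simple_tensor(inpt) -> bool:
        return isinstance(inpt, torch.Tensor) and not isinstance(inpt, FakeDatapoint)


    plain = torch.rand(3, 8, 8)
    wrapped = torch.rand(3, 8, 8).as_subclass(FakeDatapoint)

    print(isinstance(wrapped, torch.Tensor))  # True: the wrapper is still a tensor
    print(is_simple_tensor(plain))            # True: routed straight to the tensor kernels
    print(is_simple_tensor(wrapped))          # False: handled by the Datapoint branches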
torchvision/prototype/transforms/utils.py

@@ -3,16 +3,10 @@ from __future__ import annotations
 from typing import Any, Callable, List, Tuple, Type, Union
 
 import PIL.Image
-import torch
 
 from torchvision._utils import sequence_to_str
 from torchvision.prototype import datapoints
-from torchvision.prototype.datapoints._datapoint import Datapoint
-from torchvision.prototype.transforms.functional import get_dimensions, get_spatial_size
-
-
-def is_simple_tensor(inpt: Any) -> bool:
-    return isinstance(inpt, torch.Tensor) and not isinstance(inpt, Datapoint)
+from torchvision.prototype.transforms.functional import get_dimensions, get_spatial_size, is_simple_tensor
 
 
 def query_bounding_box(flat_inputs: List[Any]) -> datapoints.BoundingBox:
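
Note: after this change utils.py (and any other in-tree caller) takes the helper from the functional namespace instead of defining it locally. At the time of this commit that path is part of the prototype API, so the import below assumes a torchvision build that ships torchvision.prototype:

    # Assumes a torchvision build that includes the prototype transforms (true for the
    # development builds this commit targets, not necessarily for release wheels).
    from torchvision.prototype.transforms.functional import is_simple_tensor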