Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
f71c4308
"...transforms/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "d4d20f01e191dbacd0a0e6c8a5db5062222753ba"
Unverified
Commit
f71c4308
authored
Jan 16, 2023
by
Philip Meier
Committed by
GitHub
Jan 16, 2023
Browse files
simplify dispatcher if-elif (#7084)
parent
69ae61a1
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
64 additions
and
107 deletions
+64
-107
mypy.ini
mypy.ini
+0
-1
torchvision/prototype/transforms/_type_conversion.py
torchvision/prototype/transforms/_type_conversion.py
+1
-1
torchvision/prototype/transforms/functional/__init__.py
torchvision/prototype/transforms/functional/__init__.py
+3
-0
torchvision/prototype/transforms/functional/_augment.py
torchvision/prototype/transforms/functional/_augment.py
+3
-3
torchvision/prototype/transforms/functional/_color.py
torchvision/prototype/transforms/functional/_color.py
+11
-28
torchvision/prototype/transforms/functional/_deprecated.py
torchvision/prototype/transforms/functional/_deprecated.py
+6
-4
torchvision/prototype/transforms/functional/_geometry.py
torchvision/prototype/transforms/functional/_geometry.py
+15
-39
torchvision/prototype/transforms/functional/_meta.py
torchvision/prototype/transforms/functional/_meta.py
+9
-17
torchvision/prototype/transforms/functional/_misc.py
torchvision/prototype/transforms/functional/_misc.py
+4
-6
torchvision/prototype/transforms/functional/_temporal.py
torchvision/prototype/transforms/functional/_temporal.py
+3
-1
torchvision/prototype/transforms/functional/_utils.py
torchvision/prototype/transforms/functional/_utils.py
+8
-0
torchvision/prototype/transforms/utils.py
torchvision/prototype/transforms/utils.py
+1
-7
No files found.
mypy.ini
View file @
f71c4308
...
@@ -32,7 +32,6 @@ no_implicit_optional = True
...
@@ -32,7 +32,6 @@ no_implicit_optional = True
; warnings
; warnings
warn_unused_ignores
=
True
warn_unused_ignores
=
True
warn_return_any
=
True
; miscellaneous strictness flags
; miscellaneous strictness flags
allow_redefinition
=
True
allow_redefinition
=
True
...
...
torchvision/prototype/transforms/_type_conversion.py
View file @
f71c4308
...
@@ -46,7 +46,7 @@ class ToImageTensor(Transform):
...
@@ -46,7 +46,7 @@ class ToImageTensor(Transform):
def
_transform
(
def
_transform
(
self
,
inpt
:
Union
[
torch
.
Tensor
,
PIL
.
Image
.
Image
,
np
.
ndarray
],
params
:
Dict
[
str
,
Any
]
self
,
inpt
:
Union
[
torch
.
Tensor
,
PIL
.
Image
.
Image
,
np
.
ndarray
],
params
:
Dict
[
str
,
Any
]
)
->
datapoints
.
Image
:
)
->
datapoints
.
Image
:
return
F
.
to_image_tensor
(
inpt
)
# type: ignore[no-any-return]
return
F
.
to_image_tensor
(
inpt
)
class
ToImagePIL
(
Transform
):
class
ToImagePIL
(
Transform
):
...
...
torchvision/prototype/transforms/functional/__init__.py
View file @
f71c4308
# TODO: Add _log_api_usage_once() in all mid-level kernels. If they remain not jit-scriptable we can use decorators
# TODO: Add _log_api_usage_once() in all mid-level kernels. If they remain not jit-scriptable we can use decorators
from
torchvision.transforms
import
InterpolationMode
# usort: skip
from
torchvision.transforms
import
InterpolationMode
# usort: skip
from
._utils
import
is_simple_tensor
# usort: skip
from
._meta
import
(
from
._meta
import
(
clamp_bounding_box
,
clamp_bounding_box
,
convert_format_bounding_box
,
convert_format_bounding_box
,
...
...
torchvision/prototype/transforms/functional/_augment.py
View file @
f71c4308
...
@@ -7,6 +7,8 @@ from torchvision.prototype import datapoints
...
@@ -7,6 +7,8 @@ from torchvision.prototype import datapoints
from
torchvision.transforms.functional
import
pil_to_tensor
,
to_pil_image
from
torchvision.transforms.functional
import
pil_to_tensor
,
to_pil_image
from
torchvision.utils
import
_log_api_usage_once
from
torchvision.utils
import
_log_api_usage_once
from
._utils
import
is_simple_tensor
def
erase_image_tensor
(
def
erase_image_tensor
(
image
:
torch
.
Tensor
,
i
:
int
,
j
:
int
,
h
:
int
,
w
:
int
,
v
:
torch
.
Tensor
,
inplace
:
bool
=
False
image
:
torch
.
Tensor
,
i
:
int
,
j
:
int
,
h
:
int
,
w
:
int
,
v
:
torch
.
Tensor
,
inplace
:
bool
=
False
...
@@ -45,9 +47,7 @@ def erase(
...
@@ -45,9 +47,7 @@ def erase(
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
erase
)
_log_api_usage_once
(
erase
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
return
erase_image_tensor
(
inpt
,
i
=
i
,
j
=
j
,
h
=
h
,
w
=
w
,
v
=
v
,
inplace
=
inplace
)
return
erase_image_tensor
(
inpt
,
i
=
i
,
j
=
j
,
h
=
h
,
w
=
w
,
v
=
v
,
inplace
=
inplace
)
elif
isinstance
(
inpt
,
datapoints
.
Image
):
elif
isinstance
(
inpt
,
datapoints
.
Image
):
output
=
erase_image_tensor
(
inpt
.
as_subclass
(
torch
.
Tensor
),
i
=
i
,
j
=
j
,
h
=
h
,
w
=
w
,
v
=
v
,
inplace
=
inplace
)
output
=
erase_image_tensor
(
inpt
.
as_subclass
(
torch
.
Tensor
),
i
=
i
,
j
=
j
,
h
=
h
,
w
=
w
,
v
=
v
,
inplace
=
inplace
)
...
...
torchvision/prototype/transforms/functional/_color.py
View file @
f71c4308
...
@@ -8,6 +8,7 @@ from torchvision.transforms.functional_tensor import _max_value
...
@@ -8,6 +8,7 @@ from torchvision.transforms.functional_tensor import _max_value
from
torchvision.utils
import
_log_api_usage_once
from
torchvision.utils
import
_log_api_usage_once
from
._meta
import
_num_value_bits
,
_rgb_to_gray
,
convert_dtype_image_tensor
from
._meta
import
_num_value_bits
,
_rgb_to_gray
,
convert_dtype_image_tensor
from
._utils
import
is_simple_tensor
def
_blend
(
image1
:
torch
.
Tensor
,
image2
:
torch
.
Tensor
,
ratio
:
float
)
->
torch
.
Tensor
:
def
_blend
(
image1
:
torch
.
Tensor
,
image2
:
torch
.
Tensor
,
ratio
:
float
)
->
torch
.
Tensor
:
...
@@ -43,9 +44,7 @@ def adjust_brightness(inpt: datapoints.InputTypeJIT, brightness_factor: float) -
...
@@ -43,9 +44,7 @@ def adjust_brightness(inpt: datapoints.InputTypeJIT, brightness_factor: float) -
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_brightness
)
_log_api_usage_once
(
adjust_brightness
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
return
adjust_brightness_image_tensor
(
inpt
,
brightness_factor
=
brightness_factor
)
return
adjust_brightness_image_tensor
(
inpt
,
brightness_factor
=
brightness_factor
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
adjust_brightness
(
brightness_factor
=
brightness_factor
)
return
inpt
.
adjust_brightness
(
brightness_factor
=
brightness_factor
)
...
@@ -131,9 +130,7 @@ def adjust_contrast(inpt: datapoints.InputTypeJIT, contrast_factor: float) -> da
...
@@ -131,9 +130,7 @@ def adjust_contrast(inpt: datapoints.InputTypeJIT, contrast_factor: float) -> da
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_contrast
)
_log_api_usage_once
(
adjust_contrast
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
return
adjust_contrast_image_tensor
(
inpt
,
contrast_factor
=
contrast_factor
)
return
adjust_contrast_image_tensor
(
inpt
,
contrast_factor
=
contrast_factor
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
adjust_contrast
(
contrast_factor
=
contrast_factor
)
return
inpt
.
adjust_contrast
(
contrast_factor
=
contrast_factor
)
...
@@ -326,9 +323,7 @@ def adjust_hue(inpt: datapoints.InputTypeJIT, hue_factor: float) -> datapoints.I
...
@@ -326,9 +323,7 @@ def adjust_hue(inpt: datapoints.InputTypeJIT, hue_factor: float) -> datapoints.I
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_hue
)
_log_api_usage_once
(
adjust_hue
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
return
adjust_hue_image_tensor
(
inpt
,
hue_factor
=
hue_factor
)
return
adjust_hue_image_tensor
(
inpt
,
hue_factor
=
hue_factor
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
adjust_hue
(
hue_factor
=
hue_factor
)
return
inpt
.
adjust_hue
(
hue_factor
=
hue_factor
)
...
@@ -371,9 +366,7 @@ def adjust_gamma(inpt: datapoints.InputTypeJIT, gamma: float, gain: float = 1) -
...
@@ -371,9 +366,7 @@ def adjust_gamma(inpt: datapoints.InputTypeJIT, gamma: float, gain: float = 1) -
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_gamma
)
_log_api_usage_once
(
adjust_gamma
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
return
adjust_gamma_image_tensor
(
inpt
,
gamma
=
gamma
,
gain
=
gain
)
return
adjust_gamma_image_tensor
(
inpt
,
gamma
=
gamma
,
gain
=
gain
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
adjust_gamma
(
gamma
=
gamma
,
gain
=
gain
)
return
inpt
.
adjust_gamma
(
gamma
=
gamma
,
gain
=
gain
)
...
@@ -410,9 +403,7 @@ def posterize(inpt: datapoints.InputTypeJIT, bits: int) -> datapoints.InputTypeJ
...
@@ -410,9 +403,7 @@ def posterize(inpt: datapoints.InputTypeJIT, bits: int) -> datapoints.InputTypeJ
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
posterize
)
_log_api_usage_once
(
posterize
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
return
posterize_image_tensor
(
inpt
,
bits
=
bits
)
return
posterize_image_tensor
(
inpt
,
bits
=
bits
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
posterize
(
bits
=
bits
)
return
inpt
.
posterize
(
bits
=
bits
)
...
@@ -443,9 +434,7 @@ def solarize(inpt: datapoints.InputTypeJIT, threshold: float) -> datapoints.Inpu
...
@@ -443,9 +434,7 @@ def solarize(inpt: datapoints.InputTypeJIT, threshold: float) -> datapoints.Inpu
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
solarize
)
_log_api_usage_once
(
solarize
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
return
solarize_image_tensor
(
inpt
,
threshold
=
threshold
)
return
solarize_image_tensor
(
inpt
,
threshold
=
threshold
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
solarize
(
threshold
=
threshold
)
return
inpt
.
solarize
(
threshold
=
threshold
)
...
@@ -498,9 +487,7 @@ def autocontrast(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
...
@@ -498,9 +487,7 @@ def autocontrast(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
autocontrast
)
_log_api_usage_once
(
autocontrast
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
return
autocontrast_image_tensor
(
inpt
)
return
autocontrast_image_tensor
(
inpt
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
autocontrast
()
return
inpt
.
autocontrast
()
...
@@ -593,9 +580,7 @@ def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
...
@@ -593,9 +580,7 @@ def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
equalize
)
_log_api_usage_once
(
equalize
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
return
equalize_image_tensor
(
inpt
)
return
equalize_image_tensor
(
inpt
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
equalize
()
return
inpt
.
equalize
()
...
@@ -610,7 +595,7 @@ def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
...
@@ -610,7 +595,7 @@ def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
def
invert_image_tensor
(
image
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
invert_image_tensor
(
image
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
image
.
is_floating_point
():
if
image
.
is_floating_point
():
return
1.0
-
image
# type: ignore[no-any-return]
return
1.0
-
image
elif
image
.
dtype
==
torch
.
uint8
:
elif
image
.
dtype
==
torch
.
uint8
:
return
image
.
bitwise_not
()
return
image
.
bitwise_not
()
else
:
# signed integer dtypes
else
:
# signed integer dtypes
...
@@ -629,9 +614,7 @@ def invert(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
...
@@ -629,9 +614,7 @@ def invert(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
invert
)
_log_api_usage_once
(
invert
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
return
invert_image_tensor
(
inpt
)
return
invert_image_tensor
(
inpt
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
invert
()
return
inpt
.
invert
()
...
...
torchvision/prototype/transforms/functional/_deprecated.py
View file @
f71c4308
...
@@ -7,6 +7,8 @@ import torch
...
@@ -7,6 +7,8 @@ import torch
from
torchvision.prototype
import
datapoints
from
torchvision.prototype
import
datapoints
from
torchvision.transforms
import
functional
as
_F
from
torchvision.transforms
import
functional
as
_F
from
._utils
import
is_simple_tensor
@
torch
.
jit
.
unused
@
torch
.
jit
.
unused
def
to_grayscale
(
inpt
:
PIL
.
Image
.
Image
,
num_output_channels
:
int
=
1
)
->
PIL
.
Image
.
Image
:
def
to_grayscale
(
inpt
:
PIL
.
Image
.
Image
,
num_output_channels
:
int
=
1
)
->
PIL
.
Image
.
Image
:
...
@@ -25,14 +27,14 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
...
@@ -25,14 +27,14 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
def
rgb_to_grayscale
(
def
rgb_to_grayscale
(
inpt
:
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
],
num_output_channels
:
int
=
1
inpt
:
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
],
num_output_channels
:
int
=
1
)
->
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
]:
)
->
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
]:
if
not
torch
.
jit
.
is_scripting
()
and
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
)):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
inpt
=
inpt
.
as_subclass
(
torch
.
Tensor
)
old_color_space
=
None
elif
isinstance
(
inpt
,
torch
.
Tensor
):
old_color_space
=
datapoints
.
_image
.
_from_tensor_shape
(
inpt
.
shape
)
# type: ignore[arg-type]
old_color_space
=
datapoints
.
_image
.
_from_tensor_shape
(
inpt
.
shape
)
# type: ignore[arg-type]
else
:
else
:
old_color_space
=
None
old_color_space
=
None
if
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
)):
inpt
=
inpt
.
as_subclass
(
torch
.
Tensor
)
call
=
", num_output_channels=3"
if
num_output_channels
==
3
else
""
call
=
", num_output_channels=3"
if
num_output_channels
==
3
else
""
replacement
=
(
replacement
=
(
f
"convert_color_space(..., color_space=datapoints.ColorSpace.GRAY"
f
"convert_color_space(..., color_space=datapoints.ColorSpace.GRAY"
...
...
torchvision/prototype/transforms/functional/_geometry.py
View file @
f71c4308
...
@@ -23,6 +23,8 @@ from torchvision.utils import _log_api_usage_once
...
@@ -23,6 +23,8 @@ from torchvision.utils import _log_api_usage_once
from
._meta
import
convert_format_bounding_box
,
get_spatial_size_image_pil
from
._meta
import
convert_format_bounding_box
,
get_spatial_size_image_pil
from
._utils
import
is_simple_tensor
def
horizontal_flip_image_tensor
(
image
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
horizontal_flip_image_tensor
(
image
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
image
.
flip
(
-
1
)
return
image
.
flip
(
-
1
)
...
@@ -60,9 +62,7 @@ def horizontal_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
...
@@ -60,9 +62,7 @@ def horizontal_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
horizontal_flip
)
_log_api_usage_once
(
horizontal_flip
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
return
horizontal_flip_image_tensor
(
inpt
)
return
horizontal_flip_image_tensor
(
inpt
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
horizontal_flip
()
return
inpt
.
horizontal_flip
()
...
@@ -111,9 +111,7 @@ def vertical_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
...
@@ -111,9 +111,7 @@ def vertical_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
vertical_flip
)
_log_api_usage_once
(
vertical_flip
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
return
vertical_flip_image_tensor
(
inpt
)
return
vertical_flip_image_tensor
(
inpt
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
vertical_flip
()
return
inpt
.
vertical_flip
()
...
@@ -241,9 +239,7 @@ def resize(
...
@@ -241,9 +239,7 @@ def resize(
)
->
datapoints
.
InputTypeJIT
:
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
resize
)
_log_api_usage_once
(
resize
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
return
resize_image_tensor
(
inpt
,
size
,
interpolation
=
interpolation
,
max_size
=
max_size
,
antialias
=
antialias
)
return
resize_image_tensor
(
inpt
,
size
,
interpolation
=
interpolation
,
max_size
=
max_size
,
antialias
=
antialias
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
resize
(
size
,
interpolation
=
interpolation
,
max_size
=
max_size
,
antialias
=
antialias
)
return
inpt
.
resize
(
size
,
interpolation
=
interpolation
,
max_size
=
max_size
,
antialias
=
antialias
)
...
@@ -744,9 +740,7 @@ def affine(
...
@@ -744,9 +740,7 @@ def affine(
_log_api_usage_once
(
affine
)
_log_api_usage_once
(
affine
)
# TODO: consider deprecating integers from angle and shear on the future
# TODO: consider deprecating integers from angle and shear on the future
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
return
affine_image_tensor
(
return
affine_image_tensor
(
inpt
,
inpt
,
angle
,
angle
,
...
@@ -929,9 +923,7 @@ def rotate(
...
@@ -929,9 +923,7 @@ def rotate(
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
rotate
)
_log_api_usage_once
(
rotate
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
return
rotate_image_tensor
(
inpt
,
angle
,
interpolation
=
interpolation
,
expand
=
expand
,
fill
=
fill
,
center
=
center
)
return
rotate_image_tensor
(
inpt
,
angle
,
interpolation
=
interpolation
,
expand
=
expand
,
fill
=
fill
,
center
=
center
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
rotate
(
angle
,
interpolation
=
interpolation
,
expand
=
expand
,
fill
=
fill
,
center
=
center
)
return
inpt
.
rotate
(
angle
,
interpolation
=
interpolation
,
expand
=
expand
,
fill
=
fill
,
center
=
center
)
...
@@ -1139,9 +1131,7 @@ def pad(
...
@@ -1139,9 +1131,7 @@ def pad(
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
pad
)
_log_api_usage_once
(
pad
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
return
pad_image_tensor
(
inpt
,
padding
,
fill
=
fill
,
padding_mode
=
padding_mode
)
return
pad_image_tensor
(
inpt
,
padding
,
fill
=
fill
,
padding_mode
=
padding_mode
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
...
@@ -1219,9 +1209,7 @@ def crop(inpt: datapoints.InputTypeJIT, top: int, left: int, height: int, width:
...
@@ -1219,9 +1209,7 @@ def crop(inpt: datapoints.InputTypeJIT, top: int, left: int, height: int, width:
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
crop
)
_log_api_usage_once
(
crop
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
return
crop_image_tensor
(
inpt
,
top
,
left
,
height
,
width
)
return
crop_image_tensor
(
inpt
,
top
,
left
,
height
,
width
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
crop
(
top
,
left
,
height
,
width
)
return
inpt
.
crop
(
top
,
left
,
height
,
width
)
...
@@ -1476,9 +1464,7 @@ def perspective(
...
@@ -1476,9 +1464,7 @@ def perspective(
)
->
datapoints
.
InputTypeJIT
:
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
perspective
)
_log_api_usage_once
(
perspective
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
return
perspective_image_tensor
(
return
perspective_image_tensor
(
inpt
,
startpoints
,
endpoints
,
interpolation
=
interpolation
,
fill
=
fill
,
coefficients
=
coefficients
inpt
,
startpoints
,
endpoints
,
interpolation
=
interpolation
,
fill
=
fill
,
coefficients
=
coefficients
)
)
...
@@ -1639,9 +1625,7 @@ def elastic(
...
@@ -1639,9 +1625,7 @@ def elastic(
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
elastic
)
_log_api_usage_once
(
elastic
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
return
elastic_image_tensor
(
inpt
,
displacement
,
interpolation
=
interpolation
,
fill
=
fill
)
return
elastic_image_tensor
(
inpt
,
displacement
,
interpolation
=
interpolation
,
fill
=
fill
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
elastic
(
displacement
,
interpolation
=
interpolation
,
fill
=
fill
)
return
inpt
.
elastic
(
displacement
,
interpolation
=
interpolation
,
fill
=
fill
)
...
@@ -1754,9 +1738,7 @@ def center_crop(inpt: datapoints.InputTypeJIT, output_size: List[int]) -> datapo
...
@@ -1754,9 +1738,7 @@ def center_crop(inpt: datapoints.InputTypeJIT, output_size: List[int]) -> datapo
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
center_crop
)
_log_api_usage_once
(
center_crop
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
return
center_crop_image_tensor
(
inpt
,
output_size
)
return
center_crop_image_tensor
(
inpt
,
output_size
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
center_crop
(
output_size
)
return
inpt
.
center_crop
(
output_size
)
...
@@ -1850,9 +1832,7 @@ def resized_crop(
...
@@ -1850,9 +1832,7 @@ def resized_crop(
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
resized_crop
)
_log_api_usage_once
(
resized_crop
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
return
resized_crop_image_tensor
(
return
resized_crop_image_tensor
(
inpt
,
top
,
left
,
height
,
width
,
antialias
=
antialias
,
size
=
size
,
interpolation
=
interpolation
inpt
,
top
,
left
,
height
,
width
,
antialias
=
antialias
,
size
=
size
,
interpolation
=
interpolation
)
)
...
@@ -1935,9 +1915,7 @@ def five_crop(
...
@@ -1935,9 +1915,7 @@ def five_crop(
# TODO: consider breaking BC here to return List[datapoints.ImageTypeJIT/VideoTypeJIT] to align this op with
# TODO: consider breaking BC here to return List[datapoints.ImageTypeJIT/VideoTypeJIT] to align this op with
# `ten_crop`
# `ten_crop`
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
return
five_crop_image_tensor
(
inpt
,
size
)
return
five_crop_image_tensor
(
inpt
,
size
)
elif
isinstance
(
inpt
,
datapoints
.
Image
):
elif
isinstance
(
inpt
,
datapoints
.
Image
):
output
=
five_crop_image_tensor
(
inpt
.
as_subclass
(
torch
.
Tensor
),
size
)
output
=
five_crop_image_tensor
(
inpt
.
as_subclass
(
torch
.
Tensor
),
size
)
...
@@ -1991,9 +1969,7 @@ def ten_crop(
...
@@ -1991,9 +1969,7 @@ def ten_crop(
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
ten_crop
)
_log_api_usage_once
(
ten_crop
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
return
ten_crop_image_tensor
(
inpt
,
size
,
vertical_flip
=
vertical_flip
)
return
ten_crop_image_tensor
(
inpt
,
size
,
vertical_flip
=
vertical_flip
)
elif
isinstance
(
inpt
,
datapoints
.
Image
):
elif
isinstance
(
inpt
,
datapoints
.
Image
):
output
=
ten_crop_image_tensor
(
inpt
.
as_subclass
(
torch
.
Tensor
),
size
,
vertical_flip
=
vertical_flip
)
output
=
ten_crop_image_tensor
(
inpt
.
as_subclass
(
torch
.
Tensor
),
size
,
vertical_flip
=
vertical_flip
)
...
...
torchvision/prototype/transforms/functional/_meta.py
View file @
f71c4308
...
@@ -9,6 +9,8 @@ from torchvision.transforms.functional_tensor import _max_value
...
@@ -9,6 +9,8 @@ from torchvision.transforms.functional_tensor import _max_value
from
torchvision.utils
import
_log_api_usage_once
from
torchvision.utils
import
_log_api_usage_once
from
._utils
import
is_simple_tensor
def
get_dimensions_image_tensor
(
image
:
torch
.
Tensor
)
->
List
[
int
]:
def
get_dimensions_image_tensor
(
image
:
torch
.
Tensor
)
->
List
[
int
]:
chw
=
list
(
image
.
shape
[
-
3
:])
chw
=
list
(
image
.
shape
[
-
3
:])
...
@@ -29,9 +31,7 @@ def get_dimensions(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]
...
@@ -29,9 +31,7 @@ def get_dimensions(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
get_dimensions
)
_log_api_usage_once
(
get_dimensions
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
return
get_dimensions_image_tensor
(
inpt
)
return
get_dimensions_image_tensor
(
inpt
)
elif
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
)):
elif
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
)):
channels
=
inpt
.
num_channels
channels
=
inpt
.
num_channels
...
@@ -68,9 +68,7 @@ def get_num_channels(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJI
...
@@ -68,9 +68,7 @@ def get_num_channels(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJI
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
get_num_channels
)
_log_api_usage_once
(
get_num_channels
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
return
get_num_channels_image_tensor
(
inpt
)
return
get_num_channels_image_tensor
(
inpt
)
elif
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
)):
elif
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
)):
return
inpt
.
num_channels
return
inpt
.
num_channels
...
@@ -120,14 +118,12 @@ def get_spatial_size(inpt: datapoints.InputTypeJIT) -> List[int]:
...
@@ -120,14 +118,12 @@ def get_spatial_size(inpt: datapoints.InputTypeJIT) -> List[int]:
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
get_spatial_size
)
_log_api_usage_once
(
get_spatial_size
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
return
get_spatial_size_image_tensor
(
inpt
)
return
get_spatial_size_image_tensor
(
inpt
)
elif
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
,
datapoints
.
BoundingBox
,
datapoints
.
Mask
)):
elif
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
,
datapoints
.
BoundingBox
,
datapoints
.
Mask
)):
return
list
(
inpt
.
spatial_size
)
return
list
(
inpt
.
spatial_size
)
elif
isinstance
(
inpt
,
PIL
.
Image
.
Image
):
elif
isinstance
(
inpt
,
PIL
.
Image
.
Image
):
return
get_spatial_size_image_pil
(
inpt
)
# type: ignore[no-any-return]
return
get_spatial_size_image_pil
(
inpt
)
else
:
else
:
raise
TypeError
(
raise
TypeError
(
f
"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f
"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...
@@ -143,7 +139,7 @@ def get_num_frames(inpt: datapoints.VideoTypeJIT) -> int:
...
@@ -143,7 +139,7 @@ def get_num_frames(inpt: datapoints.VideoTypeJIT) -> int:
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
get_num_frames
)
_log_api_usage_once
(
get_num_frames
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
Video
)
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
get_num_frames_video
(
inpt
)
return
get_num_frames_video
(
inpt
)
elif
isinstance
(
inpt
,
datapoints
.
Video
):
elif
isinstance
(
inpt
,
datapoints
.
Video
):
return
inpt
.
num_frames
return
inpt
.
num_frames
...
@@ -336,9 +332,7 @@ def convert_color_space(
...
@@ -336,9 +332,7 @@ def convert_color_space(
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
convert_color_space
)
_log_api_usage_once
(
convert_color_space
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
if
old_color_space
is
None
:
if
old_color_space
is
None
:
raise
RuntimeError
(
raise
RuntimeError
(
"In order to convert the color space of simple tensors, "
"In order to convert the color space of simple tensors, "
...
@@ -443,9 +437,7 @@ def convert_dtype(
...
@@ -443,9 +437,7 @@ def convert_dtype(
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
convert_dtype
)
_log_api_usage_once
(
convert_dtype
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
return
convert_dtype_image_tensor
(
inpt
,
dtype
)
return
convert_dtype_image_tensor
(
inpt
,
dtype
)
elif
isinstance
(
inpt
,
datapoints
.
Image
):
elif
isinstance
(
inpt
,
datapoints
.
Image
):
output
=
convert_dtype_image_tensor
(
inpt
.
as_subclass
(
torch
.
Tensor
),
dtype
)
output
=
convert_dtype_image_tensor
(
inpt
.
as_subclass
(
torch
.
Tensor
),
dtype
)
...
...
torchvision/prototype/transforms/functional/_misc.py
View file @
f71c4308
...
@@ -10,7 +10,7 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image
...
@@ -10,7 +10,7 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from
torchvision.utils
import
_log_api_usage_once
from
torchvision.utils
import
_log_api_usage_once
from
.
.
utils
import
is_simple_tensor
from
.
_
utils
import
is_simple_tensor
def
normalize_image_tensor
(
def
normalize_image_tensor
(
...
@@ -61,9 +61,9 @@ def normalize(
...
@@ -61,9 +61,9 @@ def normalize(
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
normalize
)
_log_api_usage_once
(
normalize
)
if
is_simple_tensor
(
inpt
)
or
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
)):
if
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
)):
inpt
=
inpt
.
as_subclass
(
torch
.
Tensor
)
inpt
=
inpt
.
as_subclass
(
torch
.
Tensor
)
el
se
:
el
if
not
is_simple_tensor
(
inpt
)
:
raise
TypeError
(
raise
TypeError
(
f
"Input can either be a plain tensor or an `Image` or `Video` datapoint, "
f
"Input can either be a plain tensor or an `Image` or `Video` datapoint, "
f
"but got
{
type
(
inpt
)
}
instead."
f
"but got
{
type
(
inpt
)
}
instead."
...
@@ -175,9 +175,7 @@ def gaussian_blur(
...
@@ -175,9 +175,7 @@ def gaussian_blur(
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
gaussian_blur
)
_log_api_usage_once
(
gaussian_blur
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
return
gaussian_blur_image_tensor
(
inpt
,
kernel_size
=
kernel_size
,
sigma
=
sigma
)
return
gaussian_blur_image_tensor
(
inpt
,
kernel_size
=
kernel_size
,
sigma
=
sigma
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
gaussian_blur
(
kernel_size
=
kernel_size
,
sigma
=
sigma
)
return
inpt
.
gaussian_blur
(
kernel_size
=
kernel_size
,
sigma
=
sigma
)
...
...
torchvision/prototype/transforms/functional/_temporal.py
View file @
f71c4308
...
@@ -4,6 +4,8 @@ from torchvision.prototype import datapoints
...
@@ -4,6 +4,8 @@ from torchvision.prototype import datapoints
from
torchvision.utils
import
_log_api_usage_once
from
torchvision.utils
import
_log_api_usage_once
from
._utils
import
is_simple_tensor
def
uniform_temporal_subsample_video
(
video
:
torch
.
Tensor
,
num_samples
:
int
,
temporal_dim
:
int
=
-
4
)
->
torch
.
Tensor
:
def
uniform_temporal_subsample_video
(
video
:
torch
.
Tensor
,
num_samples
:
int
,
temporal_dim
:
int
=
-
4
)
->
torch
.
Tensor
:
# Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
# Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
...
@@ -18,7 +20,7 @@ def uniform_temporal_subsample(
...
@@ -18,7 +20,7 @@ def uniform_temporal_subsample(
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
uniform_temporal_subsample
)
_log_api_usage_once
(
uniform_temporal_subsample
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
Video
)
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
uniform_temporal_subsample_video
(
inpt
,
num_samples
,
temporal_dim
=
temporal_dim
)
return
uniform_temporal_subsample_video
(
inpt
,
num_samples
,
temporal_dim
=
temporal_dim
)
elif
isinstance
(
inpt
,
datapoints
.
Video
):
elif
isinstance
(
inpt
,
datapoints
.
Video
):
if
temporal_dim
!=
-
4
and
inpt
.
ndim
-
4
!=
temporal_dim
:
if
temporal_dim
!=
-
4
and
inpt
.
ndim
-
4
!=
temporal_dim
:
...
...
torchvision/prototype/transforms/functional/_utils.py
0 → 100644
View file @
f71c4308
from
typing
import
Any
import
torch
from
torchvision.prototype.datapoints._datapoint
import
Datapoint
def
is_simple_tensor
(
inpt
:
Any
)
->
bool
:
return
isinstance
(
inpt
,
torch
.
Tensor
)
and
not
isinstance
(
inpt
,
Datapoint
)
torchvision/prototype/transforms/utils.py
View file @
f71c4308
...
@@ -3,16 +3,10 @@ from __future__ import annotations
...
@@ -3,16 +3,10 @@ from __future__ import annotations
from
typing
import
Any
,
Callable
,
List
,
Tuple
,
Type
,
Union
from
typing
import
Any
,
Callable
,
List
,
Tuple
,
Type
,
Union
import
PIL.Image
import
PIL.Image
import
torch
from
torchvision._utils
import
sequence_to_str
from
torchvision._utils
import
sequence_to_str
from
torchvision.prototype
import
datapoints
from
torchvision.prototype
import
datapoints
from
torchvision.prototype.datapoints._datapoint
import
Datapoint
from
torchvision.prototype.transforms.functional
import
get_dimensions
,
get_spatial_size
,
is_simple_tensor
from
torchvision.prototype.transforms.functional
import
get_dimensions
,
get_spatial_size
def
is_simple_tensor
(
inpt
:
Any
)
->
bool
:
return
isinstance
(
inpt
,
torch
.
Tensor
)
and
not
isinstance
(
inpt
,
Datapoint
)
def
query_bounding_box
(
flat_inputs
:
List
[
Any
])
->
datapoints
.
BoundingBox
:
def
query_bounding_box
(
flat_inputs
:
List
[
Any
])
->
datapoints
.
BoundingBox
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment