Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
f71c4308
Unverified
Commit
f71c4308
authored
Jan 16, 2023
by
Philip Meier
Committed by
GitHub
Jan 16, 2023
Browse files
simplify dispatcher if-elif (#7084)
parent
69ae61a1
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
64 additions
and
107 deletions
+64
-107
mypy.ini
mypy.ini
+0
-1
torchvision/prototype/transforms/_type_conversion.py
torchvision/prototype/transforms/_type_conversion.py
+1
-1
torchvision/prototype/transforms/functional/__init__.py
torchvision/prototype/transforms/functional/__init__.py
+3
-0
torchvision/prototype/transforms/functional/_augment.py
torchvision/prototype/transforms/functional/_augment.py
+3
-3
torchvision/prototype/transforms/functional/_color.py
torchvision/prototype/transforms/functional/_color.py
+11
-28
torchvision/prototype/transforms/functional/_deprecated.py
torchvision/prototype/transforms/functional/_deprecated.py
+6
-4
torchvision/prototype/transforms/functional/_geometry.py
torchvision/prototype/transforms/functional/_geometry.py
+15
-39
torchvision/prototype/transforms/functional/_meta.py
torchvision/prototype/transforms/functional/_meta.py
+9
-17
torchvision/prototype/transforms/functional/_misc.py
torchvision/prototype/transforms/functional/_misc.py
+4
-6
torchvision/prototype/transforms/functional/_temporal.py
torchvision/prototype/transforms/functional/_temporal.py
+3
-1
torchvision/prototype/transforms/functional/_utils.py
torchvision/prototype/transforms/functional/_utils.py
+8
-0
torchvision/prototype/transforms/utils.py
torchvision/prototype/transforms/utils.py
+1
-7
No files found.
mypy.ini
View file @
f71c4308
...
...
@@ -32,7 +32,6 @@ no_implicit_optional = True
; warnings
warn_unused_ignores
=
True
warn_return_any
=
True
; miscellaneous strictness flags
allow_redefinition
=
True
...
...
torchvision/prototype/transforms/_type_conversion.py
View file @
f71c4308
...
...
@@ -46,7 +46,7 @@ class ToImageTensor(Transform):
def
_transform
(
self
,
inpt
:
Union
[
torch
.
Tensor
,
PIL
.
Image
.
Image
,
np
.
ndarray
],
params
:
Dict
[
str
,
Any
]
)
->
datapoints
.
Image
:
return
F
.
to_image_tensor
(
inpt
)
# type: ignore[no-any-return]
return
F
.
to_image_tensor
(
inpt
)
class
ToImagePIL
(
Transform
):
...
...
torchvision/prototype/transforms/functional/__init__.py
View file @
f71c4308
# TODO: Add _log_api_usage_once() in all mid-level kernels. If they remain not jit-scriptable we can use decorators
from
torchvision.transforms
import
InterpolationMode
# usort: skip
from
._utils
import
is_simple_tensor
# usort: skip
from
._meta
import
(
clamp_bounding_box
,
convert_format_bounding_box
,
...
...
torchvision/prototype/transforms/functional/_augment.py
View file @
f71c4308
...
...
@@ -7,6 +7,8 @@ from torchvision.prototype import datapoints
from
torchvision.transforms.functional
import
pil_to_tensor
,
to_pil_image
from
torchvision.utils
import
_log_api_usage_once
from
._utils
import
is_simple_tensor
def
erase_image_tensor
(
image
:
torch
.
Tensor
,
i
:
int
,
j
:
int
,
h
:
int
,
w
:
int
,
v
:
torch
.
Tensor
,
inplace
:
bool
=
False
...
...
@@ -45,9 +47,7 @@ def erase(
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
erase
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
erase_image_tensor
(
inpt
,
i
=
i
,
j
=
j
,
h
=
h
,
w
=
w
,
v
=
v
,
inplace
=
inplace
)
elif
isinstance
(
inpt
,
datapoints
.
Image
):
output
=
erase_image_tensor
(
inpt
.
as_subclass
(
torch
.
Tensor
),
i
=
i
,
j
=
j
,
h
=
h
,
w
=
w
,
v
=
v
,
inplace
=
inplace
)
...
...
torchvision/prototype/transforms/functional/_color.py
View file @
f71c4308
...
...
@@ -8,6 +8,7 @@ from torchvision.transforms.functional_tensor import _max_value
from
torchvision.utils
import
_log_api_usage_once
from
._meta
import
_num_value_bits
,
_rgb_to_gray
,
convert_dtype_image_tensor
from
._utils
import
is_simple_tensor
def
_blend
(
image1
:
torch
.
Tensor
,
image2
:
torch
.
Tensor
,
ratio
:
float
)
->
torch
.
Tensor
:
...
...
@@ -43,9 +44,7 @@ def adjust_brightness(inpt: datapoints.InputTypeJIT, brightness_factor: float) -
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_brightness
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
adjust_brightness_image_tensor
(
inpt
,
brightness_factor
=
brightness_factor
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
adjust_brightness
(
brightness_factor
=
brightness_factor
)
...
...
@@ -131,9 +130,7 @@ def adjust_contrast(inpt: datapoints.InputTypeJIT, contrast_factor: float) -> da
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_contrast
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
adjust_contrast_image_tensor
(
inpt
,
contrast_factor
=
contrast_factor
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
adjust_contrast
(
contrast_factor
=
contrast_factor
)
...
...
@@ -326,9 +323,7 @@ def adjust_hue(inpt: datapoints.InputTypeJIT, hue_factor: float) -> datapoints.I
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_hue
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
adjust_hue_image_tensor
(
inpt
,
hue_factor
=
hue_factor
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
adjust_hue
(
hue_factor
=
hue_factor
)
...
...
@@ -371,9 +366,7 @@ def adjust_gamma(inpt: datapoints.InputTypeJIT, gamma: float, gain: float = 1) -
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_gamma
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
adjust_gamma_image_tensor
(
inpt
,
gamma
=
gamma
,
gain
=
gain
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
adjust_gamma
(
gamma
=
gamma
,
gain
=
gain
)
...
...
@@ -410,9 +403,7 @@ def posterize(inpt: datapoints.InputTypeJIT, bits: int) -> datapoints.InputTypeJ
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
posterize
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
posterize_image_tensor
(
inpt
,
bits
=
bits
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
posterize
(
bits
=
bits
)
...
...
@@ -443,9 +434,7 @@ def solarize(inpt: datapoints.InputTypeJIT, threshold: float) -> datapoints.Inpu
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
solarize
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
solarize_image_tensor
(
inpt
,
threshold
=
threshold
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
solarize
(
threshold
=
threshold
)
...
...
@@ -498,9 +487,7 @@ def autocontrast(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
autocontrast
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
autocontrast_image_tensor
(
inpt
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
autocontrast
()
...
...
@@ -593,9 +580,7 @@ def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
equalize
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
equalize_image_tensor
(
inpt
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
equalize
()
...
...
@@ -610,7 +595,7 @@ def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
def
invert_image_tensor
(
image
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
image
.
is_floating_point
():
return
1.0
-
image
# type: ignore[no-any-return]
return
1.0
-
image
elif
image
.
dtype
==
torch
.
uint8
:
return
image
.
bitwise_not
()
else
:
# signed integer dtypes
...
...
@@ -629,9 +614,7 @@ def invert(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
invert
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
invert_image_tensor
(
inpt
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
invert
()
...
...
torchvision/prototype/transforms/functional/_deprecated.py
View file @
f71c4308
...
...
@@ -7,6 +7,8 @@ import torch
from
torchvision.prototype
import
datapoints
from
torchvision.transforms
import
functional
as
_F
from
._utils
import
is_simple_tensor
@
torch
.
jit
.
unused
def
to_grayscale
(
inpt
:
PIL
.
Image
.
Image
,
num_output_channels
:
int
=
1
)
->
PIL
.
Image
.
Image
:
...
...
@@ -25,14 +27,14 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
def
rgb_to_grayscale
(
inpt
:
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
],
num_output_channels
:
int
=
1
)
->
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
]:
if
not
torch
.
jit
.
is_scripting
()
and
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
)):
inpt
=
inpt
.
as_subclass
(
torch
.
Tensor
)
old_color_space
=
None
elif
isinstance
(
inpt
,
torch
.
Tensor
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
old_color_space
=
datapoints
.
_image
.
_from_tensor_shape
(
inpt
.
shape
)
# type: ignore[arg-type]
else
:
old_color_space
=
None
if
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
)):
inpt
=
inpt
.
as_subclass
(
torch
.
Tensor
)
call
=
", num_output_channels=3"
if
num_output_channels
==
3
else
""
replacement
=
(
f
"convert_color_space(..., color_space=datapoints.ColorSpace.GRAY"
...
...
torchvision/prototype/transforms/functional/_geometry.py
View file @
f71c4308
...
...
@@ -23,6 +23,8 @@ from torchvision.utils import _log_api_usage_once
from
._meta
import
convert_format_bounding_box
,
get_spatial_size_image_pil
from
._utils
import
is_simple_tensor
def
horizontal_flip_image_tensor
(
image
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
image
.
flip
(
-
1
)
...
...
@@ -60,9 +62,7 @@ def horizontal_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
horizontal_flip
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
horizontal_flip_image_tensor
(
inpt
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
horizontal_flip
()
...
...
@@ -111,9 +111,7 @@ def vertical_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
vertical_flip
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
vertical_flip_image_tensor
(
inpt
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
vertical_flip
()
...
...
@@ -241,9 +239,7 @@ def resize(
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
resize
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
resize_image_tensor
(
inpt
,
size
,
interpolation
=
interpolation
,
max_size
=
max_size
,
antialias
=
antialias
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
resize
(
size
,
interpolation
=
interpolation
,
max_size
=
max_size
,
antialias
=
antialias
)
...
...
@@ -744,9 +740,7 @@ def affine(
_log_api_usage_once
(
affine
)
# TODO: consider deprecating integers from angle and shear on the future
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
affine_image_tensor
(
inpt
,
angle
,
...
...
@@ -929,9 +923,7 @@ def rotate(
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
rotate
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
rotate_image_tensor
(
inpt
,
angle
,
interpolation
=
interpolation
,
expand
=
expand
,
fill
=
fill
,
center
=
center
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
rotate
(
angle
,
interpolation
=
interpolation
,
expand
=
expand
,
fill
=
fill
,
center
=
center
)
...
...
@@ -1139,9 +1131,7 @@ def pad(
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
pad
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
pad_image_tensor
(
inpt
,
padding
,
fill
=
fill
,
padding_mode
=
padding_mode
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
...
...
@@ -1219,9 +1209,7 @@ def crop(inpt: datapoints.InputTypeJIT, top: int, left: int, height: int, width:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
crop
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
crop_image_tensor
(
inpt
,
top
,
left
,
height
,
width
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
crop
(
top
,
left
,
height
,
width
)
...
...
@@ -1476,9 +1464,7 @@ def perspective(
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
perspective
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
perspective_image_tensor
(
inpt
,
startpoints
,
endpoints
,
interpolation
=
interpolation
,
fill
=
fill
,
coefficients
=
coefficients
)
...
...
@@ -1639,9 +1625,7 @@ def elastic(
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
elastic
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
elastic_image_tensor
(
inpt
,
displacement
,
interpolation
=
interpolation
,
fill
=
fill
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
elastic
(
displacement
,
interpolation
=
interpolation
,
fill
=
fill
)
...
...
@@ -1754,9 +1738,7 @@ def center_crop(inpt: datapoints.InputTypeJIT, output_size: List[int]) -> datapo
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
center_crop
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
center_crop_image_tensor
(
inpt
,
output_size
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
center_crop
(
output_size
)
...
...
@@ -1850,9 +1832,7 @@ def resized_crop(
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
resized_crop
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
resized_crop_image_tensor
(
inpt
,
top
,
left
,
height
,
width
,
antialias
=
antialias
,
size
=
size
,
interpolation
=
interpolation
)
...
...
@@ -1935,9 +1915,7 @@ def five_crop(
# TODO: consider breaking BC here to return List[datapoints.ImageTypeJIT/VideoTypeJIT] to align this op with
# `ten_crop`
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
five_crop_image_tensor
(
inpt
,
size
)
elif
isinstance
(
inpt
,
datapoints
.
Image
):
output
=
five_crop_image_tensor
(
inpt
.
as_subclass
(
torch
.
Tensor
),
size
)
...
...
@@ -1991,9 +1969,7 @@ def ten_crop(
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
ten_crop
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
ten_crop_image_tensor
(
inpt
,
size
,
vertical_flip
=
vertical_flip
)
elif
isinstance
(
inpt
,
datapoints
.
Image
):
output
=
ten_crop_image_tensor
(
inpt
.
as_subclass
(
torch
.
Tensor
),
size
,
vertical_flip
=
vertical_flip
)
...
...
torchvision/prototype/transforms/functional/_meta.py
View file @
f71c4308
...
...
@@ -9,6 +9,8 @@ from torchvision.transforms.functional_tensor import _max_value
from
torchvision.utils
import
_log_api_usage_once
from
._utils
import
is_simple_tensor
def
get_dimensions_image_tensor
(
image
:
torch
.
Tensor
)
->
List
[
int
]:
chw
=
list
(
image
.
shape
[
-
3
:])
...
...
@@ -29,9 +31,7 @@ def get_dimensions(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
get_dimensions
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
get_dimensions_image_tensor
(
inpt
)
elif
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
)):
channels
=
inpt
.
num_channels
...
...
@@ -68,9 +68,7 @@ def get_num_channels(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJI
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
get_num_channels
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
get_num_channels_image_tensor
(
inpt
)
elif
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
)):
return
inpt
.
num_channels
...
...
@@ -120,14 +118,12 @@ def get_spatial_size(inpt: datapoints.InputTypeJIT) -> List[int]:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
get_spatial_size
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
get_spatial_size_image_tensor
(
inpt
)
elif
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
,
datapoints
.
BoundingBox
,
datapoints
.
Mask
)):
return
list
(
inpt
.
spatial_size
)
elif
isinstance
(
inpt
,
PIL
.
Image
.
Image
):
return
get_spatial_size_image_pil
(
inpt
)
# type: ignore[no-any-return]
return
get_spatial_size_image_pil
(
inpt
)
else
:
raise
TypeError
(
f
"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...
...
@@ -143,7 +139,7 @@ def get_num_frames(inpt: datapoints.VideoTypeJIT) -> int:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
get_num_frames
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
Video
)
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
get_num_frames_video
(
inpt
)
elif
isinstance
(
inpt
,
datapoints
.
Video
):
return
inpt
.
num_frames
...
...
@@ -336,9 +332,7 @@ def convert_color_space(
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
convert_color_space
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
if
old_color_space
is
None
:
raise
RuntimeError
(
"In order to convert the color space of simple tensors, "
...
...
@@ -443,9 +437,7 @@ def convert_dtype(
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
convert_dtype
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
convert_dtype_image_tensor
(
inpt
,
dtype
)
elif
isinstance
(
inpt
,
datapoints
.
Image
):
output
=
convert_dtype_image_tensor
(
inpt
.
as_subclass
(
torch
.
Tensor
),
dtype
)
...
...
torchvision/prototype/transforms/functional/_misc.py
View file @
f71c4308
...
...
@@ -10,7 +10,7 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from
torchvision.utils
import
_log_api_usage_once
from
.
.
utils
import
is_simple_tensor
from
.
_
utils
import
is_simple_tensor
def
normalize_image_tensor
(
...
...
@@ -61,9 +61,9 @@ def normalize(
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
normalize
)
if
is_simple_tensor
(
inpt
)
or
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
)):
if
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
)):
inpt
=
inpt
.
as_subclass
(
torch
.
Tensor
)
el
se
:
el
if
not
is_simple_tensor
(
inpt
)
:
raise
TypeError
(
f
"Input can either be a plain tensor or an `Image` or `Video` datapoint, "
f
"but got
{
type
(
inpt
)
}
instead."
...
...
@@ -175,9 +175,7 @@ def gaussian_blur(
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
gaussian_blur
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
gaussian_blur_image_tensor
(
inpt
,
kernel_size
=
kernel_size
,
sigma
=
sigma
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
gaussian_blur
(
kernel_size
=
kernel_size
,
sigma
=
sigma
)
...
...
torchvision/prototype/transforms/functional/_temporal.py
View file @
f71c4308
...
...
@@ -4,6 +4,8 @@ from torchvision.prototype import datapoints
from
torchvision.utils
import
_log_api_usage_once
from
._utils
import
is_simple_tensor
def
uniform_temporal_subsample_video
(
video
:
torch
.
Tensor
,
num_samples
:
int
,
temporal_dim
:
int
=
-
4
)
->
torch
.
Tensor
:
# Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
...
...
@@ -18,7 +20,7 @@ def uniform_temporal_subsample(
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
uniform_temporal_subsample
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
Video
)
):
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
uniform_temporal_subsample_video
(
inpt
,
num_samples
,
temporal_dim
=
temporal_dim
)
elif
isinstance
(
inpt
,
datapoints
.
Video
):
if
temporal_dim
!=
-
4
and
inpt
.
ndim
-
4
!=
temporal_dim
:
...
...
torchvision/prototype/transforms/functional/_utils.py
0 → 100644
View file @
f71c4308
from
typing
import
Any
import
torch
from
torchvision.prototype.datapoints._datapoint
import
Datapoint
def
is_simple_tensor
(
inpt
:
Any
)
->
bool
:
return
isinstance
(
inpt
,
torch
.
Tensor
)
and
not
isinstance
(
inpt
,
Datapoint
)
torchvision/prototype/transforms/utils.py
View file @
f71c4308
...
...
@@ -3,16 +3,10 @@ from __future__ import annotations
from
typing
import
Any
,
Callable
,
List
,
Tuple
,
Type
,
Union
import
PIL.Image
import
torch
from
torchvision._utils
import
sequence_to_str
from
torchvision.prototype
import
datapoints
from
torchvision.prototype.datapoints._datapoint
import
Datapoint
from
torchvision.prototype.transforms.functional
import
get_dimensions
,
get_spatial_size
def
is_simple_tensor
(
inpt
:
Any
)
->
bool
:
return
isinstance
(
inpt
,
torch
.
Tensor
)
and
not
isinstance
(
inpt
,
Datapoint
)
from
torchvision.prototype.transforms.functional
import
get_dimensions
,
get_spatial_size
,
is_simple_tensor
def
query_bounding_box
(
flat_inputs
:
List
[
Any
])
->
datapoints
.
BoundingBox
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment