Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
ca012d39
"vscode:/vscode.git/clone" did not exist on "8cd071b6e9af48710d82192d68fe4e41d041d7e4"
Unverified
Commit
ca012d39
authored
Aug 16, 2023
by
Philip Meier
Committed by
GitHub
Aug 16, 2023
Browse files
make PIL kernels private (#7831)
parent
cdbbd666
Changes
25
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
125 additions
and
133 deletions
+125
-133
torchvision/transforms/v2/functional/_deprecated.py
torchvision/transforms/v2/functional/_deprecated.py
+1
-1
torchvision/transforms/v2/functional/_geometry.py
torchvision/transforms/v2/functional/_geometry.py
+97
-99
torchvision/transforms/v2/functional/_meta.py
torchvision/transforms/v2/functional/_meta.py
+13
-13
torchvision/transforms/v2/functional/_misc.py
torchvision/transforms/v2/functional/_misc.py
+12
-14
torchvision/transforms/v2/functional/_type_conversion.py
torchvision/transforms/v2/functional/_type_conversion.py
+2
-6
No files found.
torchvision/transforms/v2/functional/_deprecated.py
View file @
ca012d39
...
@@ -10,7 +10,7 @@ from torchvision.transforms import functional as _F
...
@@ -10,7 +10,7 @@ from torchvision.transforms import functional as _F
def
to_tensor
(
inpt
:
Any
)
->
torch
.
Tensor
:
def
to_tensor
(
inpt
:
Any
)
->
torch
.
Tensor
:
warnings
.
warn
(
warnings
.
warn
(
"The function `to_tensor(...)` is deprecated and will be removed in a future release. "
"The function `to_tensor(...)` is deprecated and will be removed in a future release. "
"Instead, please use `to_image
_tensor
(...)` followed by `to_dtype(..., dtype=torch.float32, scale=True)`."
"Instead, please use `to_image(...)` followed by `to_dtype(..., dtype=torch.float32, scale=True)`."
)
)
return
_F
.
to_tensor
(
inpt
)
return
_F
.
to_tensor
(
inpt
)
...
...
torchvision/transforms/v2/functional/_geometry.py
View file @
ca012d39
...
@@ -23,7 +23,7 @@ from torchvision.transforms.functional import (
...
@@ -23,7 +23,7 @@ from torchvision.transforms.functional import (
from
torchvision.utils
import
_log_api_usage_once
from
torchvision.utils
import
_log_api_usage_once
from
._meta
import
clamp_bounding_boxes
,
convert_format_bounding_boxes
,
get_size_image_pil
from
._meta
import
_get_size_image_pil
,
clamp_bounding_boxes
,
convert_format_bounding_boxes
from
._utils
import
_FillTypeJIT
,
_get_kernel
,
_register_five_ten_crop_kernel_internal
,
_register_kernel_internal
from
._utils
import
_FillTypeJIT
,
_get_kernel
,
_register_five_ten_crop_kernel_internal
,
_register_kernel_internal
...
@@ -41,7 +41,7 @@ def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> Interp
...
@@ -41,7 +41,7 @@ def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> Interp
def
horizontal_flip
(
inpt
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
horizontal_flip
(
inpt
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
horizontal_flip_image
_tensor
(
inpt
)
return
horizontal_flip_image
(
inpt
)
_log_api_usage_once
(
horizontal_flip
)
_log_api_usage_once
(
horizontal_flip
)
...
@@ -51,18 +51,18 @@ def horizontal_flip(inpt: torch.Tensor) -> torch.Tensor:
...
@@ -51,18 +51,18 @@ def horizontal_flip(inpt: torch.Tensor) -> torch.Tensor:
@
_register_kernel_internal
(
horizontal_flip
,
torch
.
Tensor
)
@
_register_kernel_internal
(
horizontal_flip
,
torch
.
Tensor
)
@
_register_kernel_internal
(
horizontal_flip
,
datapoints
.
Image
)
@
_register_kernel_internal
(
horizontal_flip
,
datapoints
.
Image
)
def
horizontal_flip_image
_tensor
(
image
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
horizontal_flip_image
(
image
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
image
.
flip
(
-
1
)
return
image
.
flip
(
-
1
)
@
_register_kernel_internal
(
horizontal_flip
,
PIL
.
Image
.
Image
)
@
_register_kernel_internal
(
horizontal_flip
,
PIL
.
Image
.
Image
)
def
horizontal_flip_image_pil
(
image
:
PIL
.
Image
.
Image
)
->
PIL
.
Image
.
Image
:
def
_
horizontal_flip_image_pil
(
image
:
PIL
.
Image
.
Image
)
->
PIL
.
Image
.
Image
:
return
_FP
.
hflip
(
image
)
return
_FP
.
hflip
(
image
)
@
_register_kernel_internal
(
horizontal_flip
,
datapoints
.
Mask
)
@
_register_kernel_internal
(
horizontal_flip
,
datapoints
.
Mask
)
def
horizontal_flip_mask
(
mask
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
horizontal_flip_mask
(
mask
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
horizontal_flip_image
_tensor
(
mask
)
return
horizontal_flip_image
(
mask
)
def
horizontal_flip_bounding_boxes
(
def
horizontal_flip_bounding_boxes
(
...
@@ -92,12 +92,12 @@ def _horizontal_flip_bounding_boxes_dispatch(inpt: datapoints.BoundingBoxes) ->
...
@@ -92,12 +92,12 @@ def _horizontal_flip_bounding_boxes_dispatch(inpt: datapoints.BoundingBoxes) ->
@
_register_kernel_internal
(
horizontal_flip
,
datapoints
.
Video
)
@
_register_kernel_internal
(
horizontal_flip
,
datapoints
.
Video
)
def
horizontal_flip_video
(
video
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
horizontal_flip_video
(
video
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
horizontal_flip_image
_tensor
(
video
)
return
horizontal_flip_image
(
video
)
def
vertical_flip
(
inpt
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
vertical_flip
(
inpt
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
vertical_flip_image
_tensor
(
inpt
)
return
vertical_flip_image
(
inpt
)
_log_api_usage_once
(
vertical_flip
)
_log_api_usage_once
(
vertical_flip
)
...
@@ -107,18 +107,18 @@ def vertical_flip(inpt: torch.Tensor) -> torch.Tensor:
...
@@ -107,18 +107,18 @@ def vertical_flip(inpt: torch.Tensor) -> torch.Tensor:
@
_register_kernel_internal
(
vertical_flip
,
torch
.
Tensor
)
@
_register_kernel_internal
(
vertical_flip
,
torch
.
Tensor
)
@
_register_kernel_internal
(
vertical_flip
,
datapoints
.
Image
)
@
_register_kernel_internal
(
vertical_flip
,
datapoints
.
Image
)
def
vertical_flip_image
_tensor
(
image
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
vertical_flip_image
(
image
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
image
.
flip
(
-
2
)
return
image
.
flip
(
-
2
)
@
_register_kernel_internal
(
vertical_flip
,
PIL
.
Image
.
Image
)
@
_register_kernel_internal
(
vertical_flip
,
PIL
.
Image
.
Image
)
def
vertical_flip_image_pil
(
image
:
PIL
.
Image
)
->
PIL
.
Image
:
def
_
vertical_flip_image_pil
(
image
:
PIL
.
Image
)
->
PIL
.
Image
:
return
_FP
.
vflip
(
image
)
return
_FP
.
vflip
(
image
)
@
_register_kernel_internal
(
vertical_flip
,
datapoints
.
Mask
)
@
_register_kernel_internal
(
vertical_flip
,
datapoints
.
Mask
)
def
vertical_flip_mask
(
mask
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
vertical_flip_mask
(
mask
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
vertical_flip_image
_tensor
(
mask
)
return
vertical_flip_image
(
mask
)
def
vertical_flip_bounding_boxes
(
def
vertical_flip_bounding_boxes
(
...
@@ -148,7 +148,7 @@ def _vertical_flip_bounding_boxes_dispatch(inpt: datapoints.BoundingBoxes) -> da
...
@@ -148,7 +148,7 @@ def _vertical_flip_bounding_boxes_dispatch(inpt: datapoints.BoundingBoxes) -> da
@
_register_kernel_internal
(
vertical_flip
,
datapoints
.
Video
)
@
_register_kernel_internal
(
vertical_flip
,
datapoints
.
Video
)
def
vertical_flip_video
(
video
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
vertical_flip_video
(
video
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
vertical_flip_image
_tensor
(
video
)
return
vertical_flip_image
(
video
)
# We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are
# We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are
...
@@ -178,7 +178,7 @@ def resize(
...
@@ -178,7 +178,7 @@ def resize(
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
resize_image
_tensor
(
inpt
,
size
=
size
,
interpolation
=
interpolation
,
max_size
=
max_size
,
antialias
=
antialias
)
return
resize_image
(
inpt
,
size
=
size
,
interpolation
=
interpolation
,
max_size
=
max_size
,
antialias
=
antialias
)
_log_api_usage_once
(
resize
)
_log_api_usage_once
(
resize
)
...
@@ -188,7 +188,7 @@ def resize(
...
@@ -188,7 +188,7 @@ def resize(
@
_register_kernel_internal
(
resize
,
torch
.
Tensor
)
@
_register_kernel_internal
(
resize
,
torch
.
Tensor
)
@
_register_kernel_internal
(
resize
,
datapoints
.
Image
)
@
_register_kernel_internal
(
resize
,
datapoints
.
Image
)
def
resize_image
_tensor
(
def
resize_image
(
image
:
torch
.
Tensor
,
image
:
torch
.
Tensor
,
size
:
List
[
int
],
size
:
List
[
int
],
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
...
@@ -267,7 +267,7 @@ def resize_image_tensor(
...
@@ -267,7 +267,7 @@ def resize_image_tensor(
return
image
.
reshape
(
shape
[:
-
3
]
+
(
num_channels
,
new_height
,
new_width
))
return
image
.
reshape
(
shape
[:
-
3
]
+
(
num_channels
,
new_height
,
new_width
))
def
resize_image_pil
(
def
_
resize_image_pil
(
image
:
PIL
.
Image
.
Image
,
image
:
PIL
.
Image
.
Image
,
size
:
Union
[
Sequence
[
int
],
int
],
size
:
Union
[
Sequence
[
int
],
int
],
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
...
@@ -289,7 +289,7 @@ def resize_image_pil(
...
@@ -289,7 +289,7 @@ def resize_image_pil(
@
_register_kernel_internal
(
resize
,
PIL
.
Image
.
Image
)
@
_register_kernel_internal
(
resize
,
PIL
.
Image
.
Image
)
def
_resize_image_pil_dispatch
(
def
_
_resize_image_pil_dispatch
(
image
:
PIL
.
Image
.
Image
,
image
:
PIL
.
Image
.
Image
,
size
:
Union
[
Sequence
[
int
],
int
],
size
:
Union
[
Sequence
[
int
],
int
],
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
...
@@ -298,7 +298,7 @@ def _resize_image_pil_dispatch(
...
@@ -298,7 +298,7 @@ def _resize_image_pil_dispatch(
)
->
PIL
.
Image
.
Image
:
)
->
PIL
.
Image
.
Image
:
if
antialias
is
False
:
if
antialias
is
False
:
warnings
.
warn
(
"Anti-alias option is always applied for PIL Image input. Argument antialias is ignored."
)
warnings
.
warn
(
"Anti-alias option is always applied for PIL Image input. Argument antialias is ignored."
)
return
resize_image_pil
(
image
,
size
=
size
,
interpolation
=
interpolation
,
max_size
=
max_size
)
return
_
resize_image_pil
(
image
,
size
=
size
,
interpolation
=
interpolation
,
max_size
=
max_size
)
def
resize_mask
(
mask
:
torch
.
Tensor
,
size
:
List
[
int
],
max_size
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
def
resize_mask
(
mask
:
torch
.
Tensor
,
size
:
List
[
int
],
max_size
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
...
@@ -308,7 +308,7 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = N
...
@@ -308,7 +308,7 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = N
else
:
else
:
needs_squeeze
=
False
needs_squeeze
=
False
output
=
resize_image
_tensor
(
mask
,
size
=
size
,
interpolation
=
InterpolationMode
.
NEAREST
,
max_size
=
max_size
)
output
=
resize_image
(
mask
,
size
=
size
,
interpolation
=
InterpolationMode
.
NEAREST
,
max_size
=
max_size
)
if
needs_squeeze
:
if
needs_squeeze
:
output
=
output
.
squeeze
(
0
)
output
=
output
.
squeeze
(
0
)
...
@@ -360,7 +360,7 @@ def resize_video(
...
@@ -360,7 +360,7 @@ def resize_video(
max_size
:
Optional
[
int
]
=
None
,
max_size
:
Optional
[
int
]
=
None
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
resize_image
_tensor
(
video
,
size
=
size
,
interpolation
=
interpolation
,
max_size
=
max_size
,
antialias
=
antialias
)
return
resize_image
(
video
,
size
=
size
,
interpolation
=
interpolation
,
max_size
=
max_size
,
antialias
=
antialias
)
def
affine
(
def
affine
(
...
@@ -374,7 +374,7 @@ def affine(
...
@@ -374,7 +374,7 @@ def affine(
center
:
Optional
[
List
[
float
]]
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
affine_image
_tensor
(
return
affine_image
(
inpt
,
inpt
,
angle
=
angle
,
angle
=
angle
,
translate
=
translate
,
translate
=
translate
,
...
@@ -648,7 +648,7 @@ def _affine_grid(
...
@@ -648,7 +648,7 @@ def _affine_grid(
@
_register_kernel_internal
(
affine
,
torch
.
Tensor
)
@
_register_kernel_internal
(
affine
,
torch
.
Tensor
)
@
_register_kernel_internal
(
affine
,
datapoints
.
Image
)
@
_register_kernel_internal
(
affine
,
datapoints
.
Image
)
def
affine_image
_tensor
(
def
affine_image
(
image
:
torch
.
Tensor
,
image
:
torch
.
Tensor
,
angle
:
Union
[
int
,
float
],
angle
:
Union
[
int
,
float
],
translate
:
List
[
float
],
translate
:
List
[
float
],
...
@@ -700,7 +700,7 @@ def affine_image_tensor(
...
@@ -700,7 +700,7 @@ def affine_image_tensor(
@
_register_kernel_internal
(
affine
,
PIL
.
Image
.
Image
)
@
_register_kernel_internal
(
affine
,
PIL
.
Image
.
Image
)
def
affine_image_pil
(
def
_
affine_image_pil
(
image
:
PIL
.
Image
.
Image
,
image
:
PIL
.
Image
.
Image
,
angle
:
Union
[
int
,
float
],
angle
:
Union
[
int
,
float
],
translate
:
List
[
float
],
translate
:
List
[
float
],
...
@@ -717,7 +717,7 @@ def affine_image_pil(
...
@@ -717,7 +717,7 @@ def affine_image_pil(
# it is visually better to estimate the center without 0.5 offset
# it is visually better to estimate the center without 0.5 offset
# otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
# otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
if
center
is
None
:
if
center
is
None
:
height
,
width
=
get_size_image_pil
(
image
)
height
,
width
=
_
get_size_image_pil
(
image
)
center
=
[
width
*
0.5
,
height
*
0.5
]
center
=
[
width
*
0.5
,
height
*
0.5
]
matrix
=
_get_inverse_affine_matrix
(
center
,
angle
,
translate
,
scale
,
shear
)
matrix
=
_get_inverse_affine_matrix
(
center
,
angle
,
translate
,
scale
,
shear
)
...
@@ -875,7 +875,7 @@ def affine_mask(
...
@@ -875,7 +875,7 @@ def affine_mask(
else
:
else
:
needs_squeeze
=
False
needs_squeeze
=
False
output
=
affine_image
_tensor
(
output
=
affine_image
(
mask
,
mask
,
angle
=
angle
,
angle
=
angle
,
translate
=
translate
,
translate
=
translate
,
...
@@ -926,7 +926,7 @@ def affine_video(
...
@@ -926,7 +926,7 @@ def affine_video(
fill
:
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
affine_image
_tensor
(
return
affine_image
(
video
,
video
,
angle
=
angle
,
angle
=
angle
,
translate
=
translate
,
translate
=
translate
,
...
@@ -947,9 +947,7 @@ def rotate(
...
@@ -947,9 +947,7 @@ def rotate(
fill
:
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
rotate_image_tensor
(
return
rotate_image
(
inpt
,
angle
=
angle
,
interpolation
=
interpolation
,
expand
=
expand
,
fill
=
fill
,
center
=
center
)
inpt
,
angle
=
angle
,
interpolation
=
interpolation
,
expand
=
expand
,
fill
=
fill
,
center
=
center
)
_log_api_usage_once
(
rotate
)
_log_api_usage_once
(
rotate
)
...
@@ -959,7 +957,7 @@ def rotate(
...
@@ -959,7 +957,7 @@ def rotate(
@
_register_kernel_internal
(
rotate
,
torch
.
Tensor
)
@
_register_kernel_internal
(
rotate
,
torch
.
Tensor
)
@
_register_kernel_internal
(
rotate
,
datapoints
.
Image
)
@
_register_kernel_internal
(
rotate
,
datapoints
.
Image
)
def
rotate_image
_tensor
(
def
rotate_image
(
image
:
torch
.
Tensor
,
image
:
torch
.
Tensor
,
angle
:
float
,
angle
:
float
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
...
@@ -1004,7 +1002,7 @@ def rotate_image_tensor(
...
@@ -1004,7 +1002,7 @@ def rotate_image_tensor(
@
_register_kernel_internal
(
rotate
,
PIL
.
Image
.
Image
)
@
_register_kernel_internal
(
rotate
,
PIL
.
Image
.
Image
)
def
rotate_image_pil
(
def
_
rotate_image_pil
(
image
:
PIL
.
Image
.
Image
,
image
:
PIL
.
Image
.
Image
,
angle
:
float
,
angle
:
float
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
NEAREST
,
...
@@ -1074,7 +1072,7 @@ def rotate_mask(
...
@@ -1074,7 +1072,7 @@ def rotate_mask(
else
:
else
:
needs_squeeze
=
False
needs_squeeze
=
False
output
=
rotate_image
_tensor
(
output
=
rotate_image
(
mask
,
mask
,
angle
=
angle
,
angle
=
angle
,
expand
=
expand
,
expand
=
expand
,
...
@@ -1111,7 +1109,7 @@ def rotate_video(
...
@@ -1111,7 +1109,7 @@ def rotate_video(
center
:
Optional
[
List
[
float
]]
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
fill
:
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
rotate_image
_tensor
(
video
,
angle
,
interpolation
=
interpolation
,
expand
=
expand
,
fill
=
fill
,
center
=
center
)
return
rotate_image
(
video
,
angle
,
interpolation
=
interpolation
,
expand
=
expand
,
fill
=
fill
,
center
=
center
)
def
pad
(
def
pad
(
...
@@ -1121,7 +1119,7 @@ def pad(
...
@@ -1121,7 +1119,7 @@ def pad(
padding_mode
:
str
=
"constant"
,
padding_mode
:
str
=
"constant"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
pad_image
_tensor
(
inpt
,
padding
=
padding
,
fill
=
fill
,
padding_mode
=
padding_mode
)
return
pad_image
(
inpt
,
padding
=
padding
,
fill
=
fill
,
padding_mode
=
padding_mode
)
_log_api_usage_once
(
pad
)
_log_api_usage_once
(
pad
)
...
@@ -1155,7 +1153,7 @@ def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
...
@@ -1155,7 +1153,7 @@ def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
@
_register_kernel_internal
(
pad
,
torch
.
Tensor
)
@
_register_kernel_internal
(
pad
,
torch
.
Tensor
)
@
_register_kernel_internal
(
pad
,
datapoints
.
Image
)
@
_register_kernel_internal
(
pad
,
datapoints
.
Image
)
def
pad_image
_tensor
(
def
pad_image
(
image
:
torch
.
Tensor
,
image
:
torch
.
Tensor
,
padding
:
List
[
int
],
padding
:
List
[
int
],
fill
:
Optional
[
Union
[
int
,
float
,
List
[
float
]]]
=
None
,
fill
:
Optional
[
Union
[
int
,
float
,
List
[
float
]]]
=
None
,
...
@@ -1253,7 +1251,7 @@ def _pad_with_vector_fill(
...
@@ -1253,7 +1251,7 @@ def _pad_with_vector_fill(
return
output
return
output
pad_image_pil
=
_register_kernel_internal
(
pad
,
PIL
.
Image
.
Image
)(
_FP
.
pad
)
_
pad_image_pil
=
_register_kernel_internal
(
pad
,
PIL
.
Image
.
Image
)(
_FP
.
pad
)
@
_register_kernel_internal
(
pad
,
datapoints
.
Mask
)
@
_register_kernel_internal
(
pad
,
datapoints
.
Mask
)
...
@@ -1275,7 +1273,7 @@ def pad_mask(
...
@@ -1275,7 +1273,7 @@ def pad_mask(
else
:
else
:
needs_squeeze
=
False
needs_squeeze
=
False
output
=
pad_image
_tensor
(
mask
,
padding
=
padding
,
fill
=
fill
,
padding_mode
=
padding_mode
)
output
=
pad_image
(
mask
,
padding
=
padding
,
fill
=
fill
,
padding_mode
=
padding_mode
)
if
needs_squeeze
:
if
needs_squeeze
:
output
=
output
.
squeeze
(
0
)
output
=
output
.
squeeze
(
0
)
...
@@ -1331,12 +1329,12 @@ def pad_video(
...
@@ -1331,12 +1329,12 @@ def pad_video(
fill
:
Optional
[
Union
[
int
,
float
,
List
[
float
]]]
=
None
,
fill
:
Optional
[
Union
[
int
,
float
,
List
[
float
]]]
=
None
,
padding_mode
:
str
=
"constant"
,
padding_mode
:
str
=
"constant"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
pad_image
_tensor
(
video
,
padding
,
fill
=
fill
,
padding_mode
=
padding_mode
)
return
pad_image
(
video
,
padding
,
fill
=
fill
,
padding_mode
=
padding_mode
)
def
crop
(
inpt
:
torch
.
Tensor
,
top
:
int
,
left
:
int
,
height
:
int
,
width
:
int
)
->
torch
.
Tensor
:
def
crop
(
inpt
:
torch
.
Tensor
,
top
:
int
,
left
:
int
,
height
:
int
,
width
:
int
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
crop_image
_tensor
(
inpt
,
top
=
top
,
left
=
left
,
height
=
height
,
width
=
width
)
return
crop_image
(
inpt
,
top
=
top
,
left
=
left
,
height
=
height
,
width
=
width
)
_log_api_usage_once
(
crop
)
_log_api_usage_once
(
crop
)
...
@@ -1346,7 +1344,7 @@ def crop(inpt: torch.Tensor, top: int, left: int, height: int, width: int) -> to
...
@@ -1346,7 +1344,7 @@ def crop(inpt: torch.Tensor, top: int, left: int, height: int, width: int) -> to
@
_register_kernel_internal
(
crop
,
torch
.
Tensor
)
@
_register_kernel_internal
(
crop
,
torch
.
Tensor
)
@
_register_kernel_internal
(
crop
,
datapoints
.
Image
)
@
_register_kernel_internal
(
crop
,
datapoints
.
Image
)
def
crop_image
_tensor
(
image
:
torch
.
Tensor
,
top
:
int
,
left
:
int
,
height
:
int
,
width
:
int
)
->
torch
.
Tensor
:
def
crop_image
(
image
:
torch
.
Tensor
,
top
:
int
,
left
:
int
,
height
:
int
,
width
:
int
)
->
torch
.
Tensor
:
h
,
w
=
image
.
shape
[
-
2
:]
h
,
w
=
image
.
shape
[
-
2
:]
right
=
left
+
width
right
=
left
+
width
...
@@ -1364,8 +1362,8 @@ def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, wid
...
@@ -1364,8 +1362,8 @@ def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, wid
return
image
[...,
top
:
bottom
,
left
:
right
]
return
image
[...,
top
:
bottom
,
left
:
right
]
crop_image_pil
=
_FP
.
crop
_
crop_image_pil
=
_FP
.
crop
_register_kernel_internal
(
crop
,
PIL
.
Image
.
Image
)(
crop_image_pil
)
_register_kernel_internal
(
crop
,
PIL
.
Image
.
Image
)(
_
crop_image_pil
)
def
crop_bounding_boxes
(
def
crop_bounding_boxes
(
...
@@ -1407,7 +1405,7 @@ def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int)
...
@@ -1407,7 +1405,7 @@ def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int)
else
:
else
:
needs_squeeze
=
False
needs_squeeze
=
False
output
=
crop_image
_tensor
(
mask
,
top
,
left
,
height
,
width
)
output
=
crop_image
(
mask
,
top
,
left
,
height
,
width
)
if
needs_squeeze
:
if
needs_squeeze
:
output
=
output
.
squeeze
(
0
)
output
=
output
.
squeeze
(
0
)
...
@@ -1417,7 +1415,7 @@ def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int)
...
@@ -1417,7 +1415,7 @@ def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int)
@
_register_kernel_internal
(
crop
,
datapoints
.
Video
)
@
_register_kernel_internal
(
crop
,
datapoints
.
Video
)
def
crop_video
(
video
:
torch
.
Tensor
,
top
:
int
,
left
:
int
,
height
:
int
,
width
:
int
)
->
torch
.
Tensor
:
def
crop_video
(
video
:
torch
.
Tensor
,
top
:
int
,
left
:
int
,
height
:
int
,
width
:
int
)
->
torch
.
Tensor
:
return
crop_image
_tensor
(
video
,
top
,
left
,
height
,
width
)
return
crop_image
(
video
,
top
,
left
,
height
,
width
)
def
perspective
(
def
perspective
(
...
@@ -1429,7 +1427,7 @@ def perspective(
...
@@ -1429,7 +1427,7 @@ def perspective(
coefficients
:
Optional
[
List
[
float
]]
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
perspective_image
_tensor
(
return
perspective_image
(
inpt
,
inpt
,
startpoints
=
startpoints
,
startpoints
=
startpoints
,
endpoints
=
endpoints
,
endpoints
=
endpoints
,
...
@@ -1500,7 +1498,7 @@ def _perspective_coefficients(
...
@@ -1500,7 +1498,7 @@ def _perspective_coefficients(
@
_register_kernel_internal
(
perspective
,
torch
.
Tensor
)
@
_register_kernel_internal
(
perspective
,
torch
.
Tensor
)
@
_register_kernel_internal
(
perspective
,
datapoints
.
Image
)
@
_register_kernel_internal
(
perspective
,
datapoints
.
Image
)
def
perspective_image
_tensor
(
def
perspective_image
(
image
:
torch
.
Tensor
,
image
:
torch
.
Tensor
,
startpoints
:
Optional
[
List
[
List
[
int
]]],
startpoints
:
Optional
[
List
[
List
[
int
]]],
endpoints
:
Optional
[
List
[
List
[
int
]]],
endpoints
:
Optional
[
List
[
List
[
int
]]],
...
@@ -1547,7 +1545,7 @@ def perspective_image_tensor(
...
@@ -1547,7 +1545,7 @@ def perspective_image_tensor(
@
_register_kernel_internal
(
perspective
,
PIL
.
Image
.
Image
)
@
_register_kernel_internal
(
perspective
,
PIL
.
Image
.
Image
)
def
perspective_image_pil
(
def
_
perspective_image_pil
(
image
:
PIL
.
Image
.
Image
,
image
:
PIL
.
Image
.
Image
,
startpoints
:
Optional
[
List
[
List
[
int
]]],
startpoints
:
Optional
[
List
[
List
[
int
]]],
endpoints
:
Optional
[
List
[
List
[
int
]]],
endpoints
:
Optional
[
List
[
List
[
int
]]],
...
@@ -1686,7 +1684,7 @@ def perspective_mask(
...
@@ -1686,7 +1684,7 @@ def perspective_mask(
else
:
else
:
needs_squeeze
=
False
needs_squeeze
=
False
output
=
perspective_image
_tensor
(
output
=
perspective_image
(
mask
,
startpoints
,
endpoints
,
interpolation
=
InterpolationMode
.
NEAREST
,
fill
=
fill
,
coefficients
=
coefficients
mask
,
startpoints
,
endpoints
,
interpolation
=
InterpolationMode
.
NEAREST
,
fill
=
fill
,
coefficients
=
coefficients
)
)
...
@@ -1724,7 +1722,7 @@ def perspective_video(
...
@@ -1724,7 +1722,7 @@ def perspective_video(
fill
:
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
perspective_image
_tensor
(
return
perspective_image
(
video
,
startpoints
,
endpoints
,
interpolation
=
interpolation
,
fill
=
fill
,
coefficients
=
coefficients
video
,
startpoints
,
endpoints
,
interpolation
=
interpolation
,
fill
=
fill
,
coefficients
=
coefficients
)
)
...
@@ -1736,7 +1734,7 @@ def elastic(
...
@@ -1736,7 +1734,7 @@ def elastic(
fill
:
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
elastic_image
_tensor
(
inpt
,
displacement
=
displacement
,
interpolation
=
interpolation
,
fill
=
fill
)
return
elastic_image
(
inpt
,
displacement
=
displacement
,
interpolation
=
interpolation
,
fill
=
fill
)
_log_api_usage_once
(
elastic
)
_log_api_usage_once
(
elastic
)
...
@@ -1749,7 +1747,7 @@ elastic_transform = elastic
...
@@ -1749,7 +1747,7 @@ elastic_transform = elastic
@
_register_kernel_internal
(
elastic
,
torch
.
Tensor
)
@
_register_kernel_internal
(
elastic
,
torch
.
Tensor
)
@
_register_kernel_internal
(
elastic
,
datapoints
.
Image
)
@
_register_kernel_internal
(
elastic
,
datapoints
.
Image
)
def
elastic_image
_tensor
(
def
elastic_image
(
image
:
torch
.
Tensor
,
image
:
torch
.
Tensor
,
displacement
:
torch
.
Tensor
,
displacement
:
torch
.
Tensor
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
...
@@ -1809,14 +1807,14 @@ def elastic_image_tensor(
...
@@ -1809,14 +1807,14 @@ def elastic_image_tensor(
@
_register_kernel_internal
(
elastic
,
PIL
.
Image
.
Image
)
@
_register_kernel_internal
(
elastic
,
PIL
.
Image
.
Image
)
def
elastic_image_pil
(
def
_
elastic_image_pil
(
image
:
PIL
.
Image
.
Image
,
image
:
PIL
.
Image
.
Image
,
displacement
:
torch
.
Tensor
,
displacement
:
torch
.
Tensor
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
)
->
PIL
.
Image
.
Image
:
)
->
PIL
.
Image
.
Image
:
t_img
=
pil_to_tensor
(
image
)
t_img
=
pil_to_tensor
(
image
)
output
=
elastic_image
_tensor
(
t_img
,
displacement
,
interpolation
=
interpolation
,
fill
=
fill
)
output
=
elastic_image
(
t_img
,
displacement
,
interpolation
=
interpolation
,
fill
=
fill
)
return
to_pil_image
(
output
,
mode
=
image
.
mode
)
return
to_pil_image
(
output
,
mode
=
image
.
mode
)
...
@@ -1910,7 +1908,7 @@ def elastic_mask(
...
@@ -1910,7 +1908,7 @@ def elastic_mask(
else
:
else
:
needs_squeeze
=
False
needs_squeeze
=
False
output
=
elastic_image
_tensor
(
mask
,
displacement
=
displacement
,
interpolation
=
InterpolationMode
.
NEAREST
,
fill
=
fill
)
output
=
elastic_image
(
mask
,
displacement
=
displacement
,
interpolation
=
InterpolationMode
.
NEAREST
,
fill
=
fill
)
if
needs_squeeze
:
if
needs_squeeze
:
output
=
output
.
squeeze
(
0
)
output
=
output
.
squeeze
(
0
)
...
@@ -1933,12 +1931,12 @@ def elastic_video(
...
@@ -1933,12 +1931,12 @@ def elastic_video(
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
_FillTypeJIT
=
None
,
fill
:
_FillTypeJIT
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
elastic_image
_tensor
(
video
,
displacement
,
interpolation
=
interpolation
,
fill
=
fill
)
return
elastic_image
(
video
,
displacement
,
interpolation
=
interpolation
,
fill
=
fill
)
def
center_crop
(
inpt
:
torch
.
Tensor
,
output_size
:
List
[
int
])
->
torch
.
Tensor
:
def
center_crop
(
inpt
:
torch
.
Tensor
,
output_size
:
List
[
int
])
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
center_crop_image
_tensor
(
inpt
,
output_size
=
output_size
)
return
center_crop_image
(
inpt
,
output_size
=
output_size
)
_log_api_usage_once
(
center_crop
)
_log_api_usage_once
(
center_crop
)
...
@@ -1975,7 +1973,7 @@ def _center_crop_compute_crop_anchor(
...
@@ -1975,7 +1973,7 @@ def _center_crop_compute_crop_anchor(
@
_register_kernel_internal
(
center_crop
,
torch
.
Tensor
)
@
_register_kernel_internal
(
center_crop
,
torch
.
Tensor
)
@
_register_kernel_internal
(
center_crop
,
datapoints
.
Image
)
@
_register_kernel_internal
(
center_crop
,
datapoints
.
Image
)
def
center_crop_image
_tensor
(
image
:
torch
.
Tensor
,
output_size
:
List
[
int
])
->
torch
.
Tensor
:
def
center_crop_image
(
image
:
torch
.
Tensor
,
output_size
:
List
[
int
])
->
torch
.
Tensor
:
crop_height
,
crop_width
=
_center_crop_parse_output_size
(
output_size
)
crop_height
,
crop_width
=
_center_crop_parse_output_size
(
output_size
)
shape
=
image
.
shape
shape
=
image
.
shape
if
image
.
numel
()
==
0
:
if
image
.
numel
()
==
0
:
...
@@ -1995,20 +1993,20 @@ def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> tor
...
@@ -1995,20 +1993,20 @@ def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> tor
@
_register_kernel_internal
(
center_crop
,
PIL
.
Image
.
Image
)
@
_register_kernel_internal
(
center_crop
,
PIL
.
Image
.
Image
)
def
center_crop_image_pil
(
image
:
PIL
.
Image
.
Image
,
output_size
:
List
[
int
])
->
PIL
.
Image
.
Image
:
def
_
center_crop_image_pil
(
image
:
PIL
.
Image
.
Image
,
output_size
:
List
[
int
])
->
PIL
.
Image
.
Image
:
crop_height
,
crop_width
=
_center_crop_parse_output_size
(
output_size
)
crop_height
,
crop_width
=
_center_crop_parse_output_size
(
output_size
)
image_height
,
image_width
=
get_size_image_pil
(
image
)
image_height
,
image_width
=
_
get_size_image_pil
(
image
)
if
crop_height
>
image_height
or
crop_width
>
image_width
:
if
crop_height
>
image_height
or
crop_width
>
image_width
:
padding_ltrb
=
_center_crop_compute_padding
(
crop_height
,
crop_width
,
image_height
,
image_width
)
padding_ltrb
=
_center_crop_compute_padding
(
crop_height
,
crop_width
,
image_height
,
image_width
)
image
=
pad_image_pil
(
image
,
padding_ltrb
,
fill
=
0
)
image
=
_
pad_image_pil
(
image
,
padding_ltrb
,
fill
=
0
)
image_height
,
image_width
=
get_size_image_pil
(
image
)
image_height
,
image_width
=
_
get_size_image_pil
(
image
)
if
crop_width
==
image_width
and
crop_height
==
image_height
:
if
crop_width
==
image_width
and
crop_height
==
image_height
:
return
image
return
image
crop_top
,
crop_left
=
_center_crop_compute_crop_anchor
(
crop_height
,
crop_width
,
image_height
,
image_width
)
crop_top
,
crop_left
=
_center_crop_compute_crop_anchor
(
crop_height
,
crop_width
,
image_height
,
image_width
)
return
crop_image_pil
(
image
,
crop_top
,
crop_left
,
crop_height
,
crop_width
)
return
_
crop_image_pil
(
image
,
crop_top
,
crop_left
,
crop_height
,
crop_width
)
def
center_crop_bounding_boxes
(
def
center_crop_bounding_boxes
(
...
@@ -2042,7 +2040,7 @@ def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor
...
@@ -2042,7 +2040,7 @@ def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor
else
:
else
:
needs_squeeze
=
False
needs_squeeze
=
False
output
=
center_crop_image
_tensor
(
image
=
mask
,
output_size
=
output_size
)
output
=
center_crop_image
(
image
=
mask
,
output_size
=
output_size
)
if
needs_squeeze
:
if
needs_squeeze
:
output
=
output
.
squeeze
(
0
)
output
=
output
.
squeeze
(
0
)
...
@@ -2052,7 +2050,7 @@ def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor
...
@@ -2052,7 +2050,7 @@ def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor
@
_register_kernel_internal
(
center_crop
,
datapoints
.
Video
)
@
_register_kernel_internal
(
center_crop
,
datapoints
.
Video
)
def
center_crop_video
(
video
:
torch
.
Tensor
,
output_size
:
List
[
int
])
->
torch
.
Tensor
:
def
center_crop_video
(
video
:
torch
.
Tensor
,
output_size
:
List
[
int
])
->
torch
.
Tensor
:
return
center_crop_image
_tensor
(
video
,
output_size
)
return
center_crop_image
(
video
,
output_size
)
def
resized_crop
(
def
resized_crop
(
...
@@ -2066,7 +2064,7 @@ def resized_crop(
...
@@ -2066,7 +2064,7 @@ def resized_crop(
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
resized_crop_image
_tensor
(
return
resized_crop_image
(
inpt
,
inpt
,
top
=
top
,
top
=
top
,
left
=
left
,
left
=
left
,
...
@@ -2094,7 +2092,7 @@ def resized_crop(
...
@@ -2094,7 +2092,7 @@ def resized_crop(
@
_register_kernel_internal
(
resized_crop
,
torch
.
Tensor
)
@
_register_kernel_internal
(
resized_crop
,
torch
.
Tensor
)
@
_register_kernel_internal
(
resized_crop
,
datapoints
.
Image
)
@
_register_kernel_internal
(
resized_crop
,
datapoints
.
Image
)
def
resized_crop_image
_tensor
(
def
resized_crop_image
(
image
:
torch
.
Tensor
,
image
:
torch
.
Tensor
,
top
:
int
,
top
:
int
,
left
:
int
,
left
:
int
,
...
@@ -2104,11 +2102,11 @@ def resized_crop_image_tensor(
...
@@ -2104,11 +2102,11 @@ def resized_crop_image_tensor(
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
image
=
crop_image
_tensor
(
image
,
top
,
left
,
height
,
width
)
image
=
crop_image
(
image
,
top
,
left
,
height
,
width
)
return
resize_image
_tensor
(
image
,
size
,
interpolation
=
interpolation
,
antialias
=
antialias
)
return
resize_image
(
image
,
size
,
interpolation
=
interpolation
,
antialias
=
antialias
)
def
resized_crop_image_pil
(
def
_
resized_crop_image_pil
(
image
:
PIL
.
Image
.
Image
,
image
:
PIL
.
Image
.
Image
,
top
:
int
,
top
:
int
,
left
:
int
,
left
:
int
,
...
@@ -2117,12 +2115,12 @@ def resized_crop_image_pil(
...
@@ -2117,12 +2115,12 @@ def resized_crop_image_pil(
size
:
List
[
int
],
size
:
List
[
int
],
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
)
->
PIL
.
Image
.
Image
:
)
->
PIL
.
Image
.
Image
:
image
=
crop_image_pil
(
image
,
top
,
left
,
height
,
width
)
image
=
_
crop_image_pil
(
image
,
top
,
left
,
height
,
width
)
return
resize_image_pil
(
image
,
size
,
interpolation
=
interpolation
)
return
_
resize_image_pil
(
image
,
size
,
interpolation
=
interpolation
)
@
_register_kernel_internal
(
resized_crop
,
PIL
.
Image
.
Image
)
@
_register_kernel_internal
(
resized_crop
,
PIL
.
Image
.
Image
)
def
resized_crop_image_pil_dispatch
(
def
_
resized_crop_image_pil_dispatch
(
image
:
PIL
.
Image
.
Image
,
image
:
PIL
.
Image
.
Image
,
top
:
int
,
top
:
int
,
left
:
int
,
left
:
int
,
...
@@ -2134,7 +2132,7 @@ def resized_crop_image_pil_dispatch(
...
@@ -2134,7 +2132,7 @@ def resized_crop_image_pil_dispatch(
)
->
PIL
.
Image
.
Image
:
)
->
PIL
.
Image
.
Image
:
if
antialias
is
False
:
if
antialias
is
False
:
warnings
.
warn
(
"Anti-alias option is always applied for PIL Image input. Argument antialias is ignored."
)
warnings
.
warn
(
"Anti-alias option is always applied for PIL Image input. Argument antialias is ignored."
)
return
resized_crop_image_pil
(
return
_
resized_crop_image_pil
(
image
,
image
,
top
=
top
,
top
=
top
,
left
=
left
,
left
=
left
,
...
@@ -2201,7 +2199,7 @@ def resized_crop_video(
...
@@ -2201,7 +2199,7 @@ def resized_crop_video(
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
resized_crop_image
_tensor
(
return
resized_crop_image
(
video
,
top
,
left
,
height
,
width
,
antialias
=
antialias
,
size
=
size
,
interpolation
=
interpolation
video
,
top
,
left
,
height
,
width
,
antialias
=
antialias
,
size
=
size
,
interpolation
=
interpolation
)
)
...
@@ -2210,7 +2208,7 @@ def five_crop(
...
@@ -2210,7 +2208,7 @@ def five_crop(
inpt
:
torch
.
Tensor
,
size
:
List
[
int
]
inpt
:
torch
.
Tensor
,
size
:
List
[
int
]
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
five_crop_image
_tensor
(
inpt
,
size
=
size
)
return
five_crop_image
(
inpt
,
size
=
size
)
_log_api_usage_once
(
five_crop
)
_log_api_usage_once
(
five_crop
)
...
@@ -2234,7 +2232,7 @@ def _parse_five_crop_size(size: List[int]) -> List[int]:
...
@@ -2234,7 +2232,7 @@ def _parse_five_crop_size(size: List[int]) -> List[int]:
@
_register_five_ten_crop_kernel_internal
(
five_crop
,
torch
.
Tensor
)
@
_register_five_ten_crop_kernel_internal
(
five_crop
,
torch
.
Tensor
)
@
_register_five_ten_crop_kernel_internal
(
five_crop
,
datapoints
.
Image
)
@
_register_five_ten_crop_kernel_internal
(
five_crop
,
datapoints
.
Image
)
def
five_crop_image
_tensor
(
def
five_crop_image
(
image
:
torch
.
Tensor
,
size
:
List
[
int
]
image
:
torch
.
Tensor
,
size
:
List
[
int
]
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
crop_height
,
crop_width
=
_parse_five_crop_size
(
size
)
crop_height
,
crop_width
=
_parse_five_crop_size
(
size
)
...
@@ -2243,30 +2241,30 @@ def five_crop_image_tensor(
...
@@ -2243,30 +2241,30 @@ def five_crop_image_tensor(
if
crop_width
>
image_width
or
crop_height
>
image_height
:
if
crop_width
>
image_width
or
crop_height
>
image_height
:
raise
ValueError
(
f
"Requested crop size
{
size
}
is bigger than input size
{
(
image_height
,
image_width
)
}
"
)
raise
ValueError
(
f
"Requested crop size
{
size
}
is bigger than input size
{
(
image_height
,
image_width
)
}
"
)
tl
=
crop_image
_tensor
(
image
,
0
,
0
,
crop_height
,
crop_width
)
tl
=
crop_image
(
image
,
0
,
0
,
crop_height
,
crop_width
)
tr
=
crop_image
_tensor
(
image
,
0
,
image_width
-
crop_width
,
crop_height
,
crop_width
)
tr
=
crop_image
(
image
,
0
,
image_width
-
crop_width
,
crop_height
,
crop_width
)
bl
=
crop_image
_tensor
(
image
,
image_height
-
crop_height
,
0
,
crop_height
,
crop_width
)
bl
=
crop_image
(
image
,
image_height
-
crop_height
,
0
,
crop_height
,
crop_width
)
br
=
crop_image
_tensor
(
image
,
image_height
-
crop_height
,
image_width
-
crop_width
,
crop_height
,
crop_width
)
br
=
crop_image
(
image
,
image_height
-
crop_height
,
image_width
-
crop_width
,
crop_height
,
crop_width
)
center
=
center_crop_image
_tensor
(
image
,
[
crop_height
,
crop_width
])
center
=
center_crop_image
(
image
,
[
crop_height
,
crop_width
])
return
tl
,
tr
,
bl
,
br
,
center
return
tl
,
tr
,
bl
,
br
,
center
@
_register_five_ten_crop_kernel_internal
(
five_crop
,
PIL
.
Image
.
Image
)
@
_register_five_ten_crop_kernel_internal
(
five_crop
,
PIL
.
Image
.
Image
)
def
five_crop_image_pil
(
def
_
five_crop_image_pil
(
image
:
PIL
.
Image
.
Image
,
size
:
List
[
int
]
image
:
PIL
.
Image
.
Image
,
size
:
List
[
int
]
)
->
Tuple
[
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
]:
)
->
Tuple
[
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
]:
crop_height
,
crop_width
=
_parse_five_crop_size
(
size
)
crop_height
,
crop_width
=
_parse_five_crop_size
(
size
)
image_height
,
image_width
=
get_size_image_pil
(
image
)
image_height
,
image_width
=
_
get_size_image_pil
(
image
)
if
crop_width
>
image_width
or
crop_height
>
image_height
:
if
crop_width
>
image_width
or
crop_height
>
image_height
:
raise
ValueError
(
f
"Requested crop size
{
size
}
is bigger than input size
{
(
image_height
,
image_width
)
}
"
)
raise
ValueError
(
f
"Requested crop size
{
size
}
is bigger than input size
{
(
image_height
,
image_width
)
}
"
)
tl
=
crop_image_pil
(
image
,
0
,
0
,
crop_height
,
crop_width
)
tl
=
_
crop_image_pil
(
image
,
0
,
0
,
crop_height
,
crop_width
)
tr
=
crop_image_pil
(
image
,
0
,
image_width
-
crop_width
,
crop_height
,
crop_width
)
tr
=
_
crop_image_pil
(
image
,
0
,
image_width
-
crop_width
,
crop_height
,
crop_width
)
bl
=
crop_image_pil
(
image
,
image_height
-
crop_height
,
0
,
crop_height
,
crop_width
)
bl
=
_
crop_image_pil
(
image
,
image_height
-
crop_height
,
0
,
crop_height
,
crop_width
)
br
=
crop_image_pil
(
image
,
image_height
-
crop_height
,
image_width
-
crop_width
,
crop_height
,
crop_width
)
br
=
_
crop_image_pil
(
image
,
image_height
-
crop_height
,
image_width
-
crop_width
,
crop_height
,
crop_width
)
center
=
center_crop_image_pil
(
image
,
[
crop_height
,
crop_width
])
center
=
_
center_crop_image_pil
(
image
,
[
crop_height
,
crop_width
])
return
tl
,
tr
,
bl
,
br
,
center
return
tl
,
tr
,
bl
,
br
,
center
...
@@ -2275,7 +2273,7 @@ def five_crop_image_pil(
...
@@ -2275,7 +2273,7 @@ def five_crop_image_pil(
def
five_crop_video
(
def
five_crop_video
(
video
:
torch
.
Tensor
,
size
:
List
[
int
]
video
:
torch
.
Tensor
,
size
:
List
[
int
]
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
return
five_crop_image
_tensor
(
video
,
size
)
return
five_crop_image
(
video
,
size
)
def
ten_crop
(
def
ten_crop
(
...
@@ -2293,7 +2291,7 @@ def ten_crop(
...
@@ -2293,7 +2291,7 @@ def ten_crop(
torch
.
Tensor
,
torch
.
Tensor
,
]:
]:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
ten_crop_image
_tensor
(
inpt
,
size
=
size
,
vertical_flip
=
vertical_flip
)
return
ten_crop_image
(
inpt
,
size
=
size
,
vertical_flip
=
vertical_flip
)
_log_api_usage_once
(
ten_crop
)
_log_api_usage_once
(
ten_crop
)
...
@@ -2303,7 +2301,7 @@ def ten_crop(
...
@@ -2303,7 +2301,7 @@ def ten_crop(
@
_register_five_ten_crop_kernel_internal
(
ten_crop
,
torch
.
Tensor
)
@
_register_five_ten_crop_kernel_internal
(
ten_crop
,
torch
.
Tensor
)
@
_register_five_ten_crop_kernel_internal
(
ten_crop
,
datapoints
.
Image
)
@
_register_five_ten_crop_kernel_internal
(
ten_crop
,
datapoints
.
Image
)
def
ten_crop_image
_tensor
(
def
ten_crop_image
(
image
:
torch
.
Tensor
,
size
:
List
[
int
],
vertical_flip
:
bool
=
False
image
:
torch
.
Tensor
,
size
:
List
[
int
],
vertical_flip
:
bool
=
False
)
->
Tuple
[
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
...
@@ -2317,20 +2315,20 @@ def ten_crop_image_tensor(
...
@@ -2317,20 +2315,20 @@ def ten_crop_image_tensor(
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
]:
]:
non_flipped
=
five_crop_image
_tensor
(
image
,
size
)
non_flipped
=
five_crop_image
(
image
,
size
)
if
vertical_flip
:
if
vertical_flip
:
image
=
vertical_flip_image
_tensor
(
image
)
image
=
vertical_flip_image
(
image
)
else
:
else
:
image
=
horizontal_flip_image
_tensor
(
image
)
image
=
horizontal_flip_image
(
image
)
flipped
=
five_crop_image
_tensor
(
image
,
size
)
flipped
=
five_crop_image
(
image
,
size
)
return
non_flipped
+
flipped
return
non_flipped
+
flipped
@
_register_five_ten_crop_kernel_internal
(
ten_crop
,
PIL
.
Image
.
Image
)
@
_register_five_ten_crop_kernel_internal
(
ten_crop
,
PIL
.
Image
.
Image
)
def
ten_crop_image_pil
(
def
_
ten_crop_image_pil
(
image
:
PIL
.
Image
.
Image
,
size
:
List
[
int
],
vertical_flip
:
bool
=
False
image
:
PIL
.
Image
.
Image
,
size
:
List
[
int
],
vertical_flip
:
bool
=
False
)
->
Tuple
[
)
->
Tuple
[
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
,
...
@@ -2344,14 +2342,14 @@ def ten_crop_image_pil(
...
@@ -2344,14 +2342,14 @@ def ten_crop_image_pil(
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
,
]:
]:
non_flipped
=
five_crop_image_pil
(
image
,
size
)
non_flipped
=
_
five_crop_image_pil
(
image
,
size
)
if
vertical_flip
:
if
vertical_flip
:
image
=
vertical_flip_image_pil
(
image
)
image
=
_
vertical_flip_image_pil
(
image
)
else
:
else
:
image
=
horizontal_flip_image_pil
(
image
)
image
=
_
horizontal_flip_image_pil
(
image
)
flipped
=
five_crop_image_pil
(
image
,
size
)
flipped
=
_
five_crop_image_pil
(
image
,
size
)
return
non_flipped
+
flipped
return
non_flipped
+
flipped
...
@@ -2371,4 +2369,4 @@ def ten_crop_video(
...
@@ -2371,4 +2369,4 @@ def ten_crop_video(
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
]:
]:
return
ten_crop_image
_tensor
(
video
,
size
,
vertical_flip
=
vertical_flip
)
return
ten_crop_image
(
video
,
size
,
vertical_flip
=
vertical_flip
)
torchvision/transforms/v2/functional/_meta.py
View file @
ca012d39
...
@@ -13,7 +13,7 @@ from ._utils import _get_kernel, _register_kernel_internal, is_simple_tensor
...
@@ -13,7 +13,7 @@ from ._utils import _get_kernel, _register_kernel_internal, is_simple_tensor
def
get_dimensions
(
inpt
:
torch
.
Tensor
)
->
List
[
int
]:
def
get_dimensions
(
inpt
:
torch
.
Tensor
)
->
List
[
int
]:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
get_dimensions_image
_tensor
(
inpt
)
return
get_dimensions_image
(
inpt
)
_log_api_usage_once
(
get_dimensions
)
_log_api_usage_once
(
get_dimensions
)
...
@@ -23,7 +23,7 @@ def get_dimensions(inpt: torch.Tensor) -> List[int]:
...
@@ -23,7 +23,7 @@ def get_dimensions(inpt: torch.Tensor) -> List[int]:
@
_register_kernel_internal
(
get_dimensions
,
torch
.
Tensor
)
@
_register_kernel_internal
(
get_dimensions
,
torch
.
Tensor
)
@
_register_kernel_internal
(
get_dimensions
,
datapoints
.
Image
,
datapoint_wrapper
=
False
)
@
_register_kernel_internal
(
get_dimensions
,
datapoints
.
Image
,
datapoint_wrapper
=
False
)
def
get_dimensions_image
_tensor
(
image
:
torch
.
Tensor
)
->
List
[
int
]:
def
get_dimensions_image
(
image
:
torch
.
Tensor
)
->
List
[
int
]:
chw
=
list
(
image
.
shape
[
-
3
:])
chw
=
list
(
image
.
shape
[
-
3
:])
ndims
=
len
(
chw
)
ndims
=
len
(
chw
)
if
ndims
==
3
:
if
ndims
==
3
:
...
@@ -35,17 +35,17 @@ def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
...
@@ -35,17 +35,17 @@ def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
raise
TypeError
(
f
"Input tensor should have at least two dimensions, but got
{
ndims
}
"
)
raise
TypeError
(
f
"Input tensor should have at least two dimensions, but got
{
ndims
}
"
)
get_dimensions_image_pil
=
_register_kernel_internal
(
get_dimensions
,
PIL
.
Image
.
Image
)(
_FP
.
get_dimensions
)
_
get_dimensions_image_pil
=
_register_kernel_internal
(
get_dimensions
,
PIL
.
Image
.
Image
)(
_FP
.
get_dimensions
)
@
_register_kernel_internal
(
get_dimensions
,
datapoints
.
Video
,
datapoint_wrapper
=
False
)
@
_register_kernel_internal
(
get_dimensions
,
datapoints
.
Video
,
datapoint_wrapper
=
False
)
def
get_dimensions_video
(
video
:
torch
.
Tensor
)
->
List
[
int
]:
def
get_dimensions_video
(
video
:
torch
.
Tensor
)
->
List
[
int
]:
return
get_dimensions_image
_tensor
(
video
)
return
get_dimensions_image
(
video
)
def
get_num_channels
(
inpt
:
torch
.
Tensor
)
->
int
:
def
get_num_channels
(
inpt
:
torch
.
Tensor
)
->
int
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
get_num_channels_image
_tensor
(
inpt
)
return
get_num_channels_image
(
inpt
)
_log_api_usage_once
(
get_num_channels
)
_log_api_usage_once
(
get_num_channels
)
...
@@ -55,7 +55,7 @@ def get_num_channels(inpt: torch.Tensor) -> int:
...
@@ -55,7 +55,7 @@ def get_num_channels(inpt: torch.Tensor) -> int:
@
_register_kernel_internal
(
get_num_channels
,
torch
.
Tensor
)
@
_register_kernel_internal
(
get_num_channels
,
torch
.
Tensor
)
@
_register_kernel_internal
(
get_num_channels
,
datapoints
.
Image
,
datapoint_wrapper
=
False
)
@
_register_kernel_internal
(
get_num_channels
,
datapoints
.
Image
,
datapoint_wrapper
=
False
)
def
get_num_channels_image
_tensor
(
image
:
torch
.
Tensor
)
->
int
:
def
get_num_channels_image
(
image
:
torch
.
Tensor
)
->
int
:
chw
=
image
.
shape
[
-
3
:]
chw
=
image
.
shape
[
-
3
:]
ndims
=
len
(
chw
)
ndims
=
len
(
chw
)
if
ndims
==
3
:
if
ndims
==
3
:
...
@@ -66,12 +66,12 @@ def get_num_channels_image_tensor(image: torch.Tensor) -> int:
...
@@ -66,12 +66,12 @@ def get_num_channels_image_tensor(image: torch.Tensor) -> int:
raise
TypeError
(
f
"Input tensor should have at least two dimensions, but got
{
ndims
}
"
)
raise
TypeError
(
f
"Input tensor should have at least two dimensions, but got
{
ndims
}
"
)
get_num_channels_image_pil
=
_register_kernel_internal
(
get_num_channels
,
PIL
.
Image
.
Image
)(
_FP
.
get_image_num_channels
)
_
get_num_channels_image_pil
=
_register_kernel_internal
(
get_num_channels
,
PIL
.
Image
.
Image
)(
_FP
.
get_image_num_channels
)
@
_register_kernel_internal
(
get_num_channels
,
datapoints
.
Video
,
datapoint_wrapper
=
False
)
@
_register_kernel_internal
(
get_num_channels
,
datapoints
.
Video
,
datapoint_wrapper
=
False
)
def
get_num_channels_video
(
video
:
torch
.
Tensor
)
->
int
:
def
get_num_channels_video
(
video
:
torch
.
Tensor
)
->
int
:
return
get_num_channels_image
_tensor
(
video
)
return
get_num_channels_image
(
video
)
# We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without
# We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without
...
@@ -81,7 +81,7 @@ get_image_num_channels = get_num_channels
...
@@ -81,7 +81,7 @@ get_image_num_channels = get_num_channels
def
get_size
(
inpt
:
torch
.
Tensor
)
->
List
[
int
]:
def
get_size
(
inpt
:
torch
.
Tensor
)
->
List
[
int
]:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
get_size_image
_tensor
(
inpt
)
return
get_size_image
(
inpt
)
_log_api_usage_once
(
get_size
)
_log_api_usage_once
(
get_size
)
...
@@ -91,7 +91,7 @@ def get_size(inpt: torch.Tensor) -> List[int]:
...
@@ -91,7 +91,7 @@ def get_size(inpt: torch.Tensor) -> List[int]:
@
_register_kernel_internal
(
get_size
,
torch
.
Tensor
)
@
_register_kernel_internal
(
get_size
,
torch
.
Tensor
)
@
_register_kernel_internal
(
get_size
,
datapoints
.
Image
,
datapoint_wrapper
=
False
)
@
_register_kernel_internal
(
get_size
,
datapoints
.
Image
,
datapoint_wrapper
=
False
)
def
get_size_image
_tensor
(
image
:
torch
.
Tensor
)
->
List
[
int
]:
def
get_size_image
(
image
:
torch
.
Tensor
)
->
List
[
int
]:
hw
=
list
(
image
.
shape
[
-
2
:])
hw
=
list
(
image
.
shape
[
-
2
:])
ndims
=
len
(
hw
)
ndims
=
len
(
hw
)
if
ndims
==
2
:
if
ndims
==
2
:
...
@@ -101,19 +101,19 @@ def get_size_image_tensor(image: torch.Tensor) -> List[int]:
...
@@ -101,19 +101,19 @@ def get_size_image_tensor(image: torch.Tensor) -> List[int]:
@
_register_kernel_internal
(
get_size
,
PIL
.
Image
.
Image
)
@
_register_kernel_internal
(
get_size
,
PIL
.
Image
.
Image
)
def
get_size_image_pil
(
image
:
PIL
.
Image
.
Image
)
->
List
[
int
]:
def
_
get_size_image_pil
(
image
:
PIL
.
Image
.
Image
)
->
List
[
int
]:
width
,
height
=
_FP
.
get_image_size
(
image
)
width
,
height
=
_FP
.
get_image_size
(
image
)
return
[
height
,
width
]
return
[
height
,
width
]
@
_register_kernel_internal
(
get_size
,
datapoints
.
Video
,
datapoint_wrapper
=
False
)
@
_register_kernel_internal
(
get_size
,
datapoints
.
Video
,
datapoint_wrapper
=
False
)
def
get_size_video
(
video
:
torch
.
Tensor
)
->
List
[
int
]:
def
get_size_video
(
video
:
torch
.
Tensor
)
->
List
[
int
]:
return
get_size_image
_tensor
(
video
)
return
get_size_image
(
video
)
@
_register_kernel_internal
(
get_size
,
datapoints
.
Mask
,
datapoint_wrapper
=
False
)
@
_register_kernel_internal
(
get_size
,
datapoints
.
Mask
,
datapoint_wrapper
=
False
)
def
get_size_mask
(
mask
:
torch
.
Tensor
)
->
List
[
int
]:
def
get_size_mask
(
mask
:
torch
.
Tensor
)
->
List
[
int
]:
return
get_size_image
_tensor
(
mask
)
return
get_size_image
(
mask
)
@
_register_kernel_internal
(
get_size
,
datapoints
.
BoundingBoxes
,
datapoint_wrapper
=
False
)
@
_register_kernel_internal
(
get_size
,
datapoints
.
BoundingBoxes
,
datapoint_wrapper
=
False
)
...
...
torchvision/transforms/v2/functional/_misc.py
View file @
ca012d39
...
@@ -21,7 +21,7 @@ def normalize(
...
@@ -21,7 +21,7 @@ def normalize(
inplace
:
bool
=
False
,
inplace
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
normalize_image
_tensor
(
inpt
,
mean
=
mean
,
std
=
std
,
inplace
=
inplace
)
return
normalize_image
(
inpt
,
mean
=
mean
,
std
=
std
,
inplace
=
inplace
)
_log_api_usage_once
(
normalize
)
_log_api_usage_once
(
normalize
)
...
@@ -31,9 +31,7 @@ def normalize(
...
@@ -31,9 +31,7 @@ def normalize(
@
_register_kernel_internal
(
normalize
,
torch
.
Tensor
)
@
_register_kernel_internal
(
normalize
,
torch
.
Tensor
)
@
_register_kernel_internal
(
normalize
,
datapoints
.
Image
)
@
_register_kernel_internal
(
normalize
,
datapoints
.
Image
)
def
normalize_image_tensor
(
def
normalize_image
(
image
:
torch
.
Tensor
,
mean
:
List
[
float
],
std
:
List
[
float
],
inplace
:
bool
=
False
)
->
torch
.
Tensor
:
image
:
torch
.
Tensor
,
mean
:
List
[
float
],
std
:
List
[
float
],
inplace
:
bool
=
False
)
->
torch
.
Tensor
:
if
not
image
.
is_floating_point
():
if
not
image
.
is_floating_point
():
raise
TypeError
(
f
"Input tensor should be a float tensor. Got
{
image
.
dtype
}
."
)
raise
TypeError
(
f
"Input tensor should be a float tensor. Got
{
image
.
dtype
}
."
)
...
@@ -68,12 +66,12 @@ def normalize_image_tensor(
...
@@ -68,12 +66,12 @@ def normalize_image_tensor(
@
_register_kernel_internal
(
normalize
,
datapoints
.
Video
)
@
_register_kernel_internal
(
normalize
,
datapoints
.
Video
)
def
normalize_video
(
video
:
torch
.
Tensor
,
mean
:
List
[
float
],
std
:
List
[
float
],
inplace
:
bool
=
False
)
->
torch
.
Tensor
:
def
normalize_video
(
video
:
torch
.
Tensor
,
mean
:
List
[
float
],
std
:
List
[
float
],
inplace
:
bool
=
False
)
->
torch
.
Tensor
:
return
normalize_image
_tensor
(
video
,
mean
,
std
,
inplace
=
inplace
)
return
normalize_image
(
video
,
mean
,
std
,
inplace
=
inplace
)
def
gaussian_blur
(
inpt
:
torch
.
Tensor
,
kernel_size
:
List
[
int
],
sigma
:
Optional
[
List
[
float
]]
=
None
)
->
torch
.
Tensor
:
def
gaussian_blur
(
inpt
:
torch
.
Tensor
,
kernel_size
:
List
[
int
],
sigma
:
Optional
[
List
[
float
]]
=
None
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
gaussian_blur_image
_tensor
(
inpt
,
kernel_size
=
kernel_size
,
sigma
=
sigma
)
return
gaussian_blur_image
(
inpt
,
kernel_size
=
kernel_size
,
sigma
=
sigma
)
_log_api_usage_once
(
gaussian_blur
)
_log_api_usage_once
(
gaussian_blur
)
...
@@ -99,7 +97,7 @@ def _get_gaussian_kernel2d(
...
@@ -99,7 +97,7 @@ def _get_gaussian_kernel2d(
@
_register_kernel_internal
(
gaussian_blur
,
torch
.
Tensor
)
@
_register_kernel_internal
(
gaussian_blur
,
torch
.
Tensor
)
@
_register_kernel_internal
(
gaussian_blur
,
datapoints
.
Image
)
@
_register_kernel_internal
(
gaussian_blur
,
datapoints
.
Image
)
def
gaussian_blur_image
_tensor
(
def
gaussian_blur_image
(
image
:
torch
.
Tensor
,
kernel_size
:
List
[
int
],
sigma
:
Optional
[
List
[
float
]]
=
None
image
:
torch
.
Tensor
,
kernel_size
:
List
[
int
],
sigma
:
Optional
[
List
[
float
]]
=
None
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# TODO: consider deprecating integers from sigma on the future
# TODO: consider deprecating integers from sigma on the future
...
@@ -164,11 +162,11 @@ def gaussian_blur_image_tensor(
...
@@ -164,11 +162,11 @@ def gaussian_blur_image_tensor(
@
_register_kernel_internal
(
gaussian_blur
,
PIL
.
Image
.
Image
)
@
_register_kernel_internal
(
gaussian_blur
,
PIL
.
Image
.
Image
)
def
gaussian_blur_image_pil
(
def
_
gaussian_blur_image_pil
(
image
:
PIL
.
Image
.
Image
,
kernel_size
:
List
[
int
],
sigma
:
Optional
[
List
[
float
]]
=
None
image
:
PIL
.
Image
.
Image
,
kernel_size
:
List
[
int
],
sigma
:
Optional
[
List
[
float
]]
=
None
)
->
PIL
.
Image
.
Image
:
)
->
PIL
.
Image
.
Image
:
t_img
=
pil_to_tensor
(
image
)
t_img
=
pil_to_tensor
(
image
)
output
=
gaussian_blur_image
_tensor
(
t_img
,
kernel_size
=
kernel_size
,
sigma
=
sigma
)
output
=
gaussian_blur_image
(
t_img
,
kernel_size
=
kernel_size
,
sigma
=
sigma
)
return
to_pil_image
(
output
,
mode
=
image
.
mode
)
return
to_pil_image
(
output
,
mode
=
image
.
mode
)
...
@@ -176,12 +174,12 @@ def gaussian_blur_image_pil(
...
@@ -176,12 +174,12 @@ def gaussian_blur_image_pil(
def
gaussian_blur_video
(
def
gaussian_blur_video
(
video
:
torch
.
Tensor
,
kernel_size
:
List
[
int
],
sigma
:
Optional
[
List
[
float
]]
=
None
video
:
torch
.
Tensor
,
kernel_size
:
List
[
int
],
sigma
:
Optional
[
List
[
float
]]
=
None
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
gaussian_blur_image
_tensor
(
video
,
kernel_size
,
sigma
)
return
gaussian_blur_image
(
video
,
kernel_size
,
sigma
)
def
to_dtype
(
inpt
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
=
torch
.
float
,
scale
:
bool
=
False
)
->
torch
.
Tensor
:
def
to_dtype
(
inpt
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
=
torch
.
float
,
scale
:
bool
=
False
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
to_dtype_image
_tensor
(
inpt
,
dtype
=
dtype
,
scale
=
scale
)
return
to_dtype_image
(
inpt
,
dtype
=
dtype
,
scale
=
scale
)
_log_api_usage_once
(
to_dtype
)
_log_api_usage_once
(
to_dtype
)
...
@@ -206,7 +204,7 @@ def _num_value_bits(dtype: torch.dtype) -> int:
...
@@ -206,7 +204,7 @@ def _num_value_bits(dtype: torch.dtype) -> int:
@
_register_kernel_internal
(
to_dtype
,
torch
.
Tensor
)
@
_register_kernel_internal
(
to_dtype
,
torch
.
Tensor
)
@
_register_kernel_internal
(
to_dtype
,
datapoints
.
Image
)
@
_register_kernel_internal
(
to_dtype
,
datapoints
.
Image
)
def
to_dtype_image
_tensor
(
image
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
=
torch
.
float
,
scale
:
bool
=
False
)
->
torch
.
Tensor
:
def
to_dtype_image
(
image
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
=
torch
.
float
,
scale
:
bool
=
False
)
->
torch
.
Tensor
:
if
image
.
dtype
==
dtype
:
if
image
.
dtype
==
dtype
:
return
image
return
image
...
@@ -260,12 +258,12 @@ def to_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float,
...
@@ -260,12 +258,12 @@ def to_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float,
# We encourage users to use to_dtype() instead but we keep this for BC
# We encourage users to use to_dtype() instead but we keep this for BC
def
convert_image_dtype
(
image
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
=
torch
.
float32
)
->
torch
.
Tensor
:
def
convert_image_dtype
(
image
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
=
torch
.
float32
)
->
torch
.
Tensor
:
return
to_dtype_image
_tensor
(
image
,
dtype
=
dtype
,
scale
=
True
)
return
to_dtype_image
(
image
,
dtype
=
dtype
,
scale
=
True
)
@
_register_kernel_internal
(
to_dtype
,
datapoints
.
Video
)
@
_register_kernel_internal
(
to_dtype
,
datapoints
.
Video
)
def
to_dtype_video
(
video
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
=
torch
.
float
,
scale
:
bool
=
False
)
->
torch
.
Tensor
:
def
to_dtype_video
(
video
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
=
torch
.
float
,
scale
:
bool
=
False
)
->
torch
.
Tensor
:
return
to_dtype_image
_tensor
(
video
,
dtype
,
scale
=
scale
)
return
to_dtype_image
(
video
,
dtype
,
scale
=
scale
)
@
_register_kernel_internal
(
to_dtype
,
datapoints
.
BoundingBoxes
,
datapoint_wrapper
=
False
)
@
_register_kernel_internal
(
to_dtype
,
datapoints
.
BoundingBoxes
,
datapoint_wrapper
=
False
)
...
...
torchvision/transforms/v2/functional/_type_conversion.py
View file @
ca012d39
...
@@ -8,7 +8,7 @@ from torchvision.transforms import functional as _F
...
@@ -8,7 +8,7 @@ from torchvision.transforms import functional as _F
@
torch
.
jit
.
unused
@
torch
.
jit
.
unused
def
to_image
_tensor
(
inpt
:
Union
[
torch
.
Tensor
,
PIL
.
Image
.
Image
,
np
.
ndarray
])
->
datapoints
.
Image
:
def
to_image
(
inpt
:
Union
[
torch
.
Tensor
,
PIL
.
Image
.
Image
,
np
.
ndarray
])
->
datapoints
.
Image
:
if
isinstance
(
inpt
,
np
.
ndarray
):
if
isinstance
(
inpt
,
np
.
ndarray
):
output
=
torch
.
from_numpy
(
inpt
).
permute
((
2
,
0
,
1
)).
contiguous
()
output
=
torch
.
from_numpy
(
inpt
).
permute
((
2
,
0
,
1
)).
contiguous
()
elif
isinstance
(
inpt
,
PIL
.
Image
.
Image
):
elif
isinstance
(
inpt
,
PIL
.
Image
.
Image
):
...
@@ -20,9 +20,5 @@ def to_image_tensor(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> d
...
@@ -20,9 +20,5 @@ def to_image_tensor(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> d
return
datapoints
.
Image
(
output
)
return
datapoints
.
Image
(
output
)
to_image
_pil
=
_F
.
to_pil_image
to_
pil_
image
=
_F
.
to_pil_image
pil_to_tensor
=
_F
.
pil_to_tensor
pil_to_tensor
=
_F
.
pil_to_tensor
# We changed the names to align them with the new naming scheme. Still, `to_pil_image` is
# prevalent and well understood. Thus, we just alias it without deprecating the old name.
to_pil_image
=
to_image_pil
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment