Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
b570f2c1
Unverified
Commit
b570f2c1
authored
Feb 14, 2023
by
Philip Meier
Committed by
GitHub
Feb 14, 2023
Browse files
Undeprecate PIL int constants for interpolation (#7241)
parent
70745705
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
110 additions
and
73 deletions
+110
-73
test/test_functional_tensor.py
test/test_functional_tensor.py
+31
-6
test/test_onnx.py
test/test_onnx.py
+1
-2
test/test_transforms.py
test/test_transforms.py
+11
-7
test/test_transforms_tensor.py
test/test_transforms_tensor.py
+4
-4
torchvision/transforms/_pil_constants.py
torchvision/transforms/_pil_constants.py
+0
-25
torchvision/transforms/functional.py
torchvision/transforms/functional.py
+32
-11
torchvision/transforms/functional_pil.py
torchvision/transforms/functional_pil.py
+8
-9
torchvision/transforms/transforms.py
torchvision/transforms/transforms.py
+23
-9
No files found.
test/test_functional_tensor.py
View file @
b570f2c1
...
...
@@ -7,6 +7,7 @@ from functools import partial
from
typing
import
Sequence
import
numpy
as
np
import
PIL.Image
import
pytest
import
torch
import
torchvision.transforms
as
T
...
...
@@ -144,6 +145,12 @@ class TestRotate:
center
=
(
20
,
22
)
_test_fn_on_batch
(
batch_tensors
,
F
.
rotate
,
angle
=
32
,
interpolation
=
NEAREST
,
expand
=
True
,
center
=
center
)
def
test_rotate_interpolation_type
(
self
):
tensor
,
_
=
_create_data
(
26
,
26
)
res1
=
F
.
rotate
(
tensor
,
45
,
interpolation
=
PIL
.
Image
.
BILINEAR
)
res2
=
F
.
rotate
(
tensor
,
45
,
interpolation
=
BILINEAR
)
assert_equal
(
res1
,
res2
)
class
TestAffine
:
...
...
@@ -350,6 +357,14 @@ class TestAffine:
_test_fn_on_batch
(
batch_tensors
,
F
.
affine
,
angle
=-
43
,
translate
=
[
-
3
,
4
],
scale
=
1.2
,
shear
=
[
4.0
,
5.0
])
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
def
test_interpolation_type
(
self
,
device
):
tensor
,
pil_img
=
_create_data
(
26
,
26
,
device
=
device
)
res1
=
F
.
affine
(
tensor
,
45
,
translate
=
[
0
,
0
],
scale
=
1.0
,
shear
=
[
0.0
,
0.0
],
interpolation
=
PIL
.
Image
.
BILINEAR
)
res2
=
F
.
affine
(
tensor
,
45
,
translate
=
[
0
,
0
],
scale
=
1.0
,
shear
=
[
0.0
,
0.0
],
interpolation
=
BILINEAR
)
assert_equal
(
res1
,
res2
)
def
_get_data_dims_and_points_for_perspective
():
# Ideally we would parametrize independently over data dims and points, but
...
...
@@ -448,6 +463,16 @@ def test_perspective_batch(device, dims_and_points, dt):
)
def
test_perspective_interpolation_type
():
spoints
=
[[
0
,
0
],
[
33
,
0
],
[
33
,
25
],
[
0
,
25
]]
epoints
=
[[
3
,
2
],
[
32
,
3
],
[
30
,
24
],
[
2
,
25
]]
tensor
=
torch
.
randint
(
0
,
256
,
(
3
,
26
,
26
))
res1
=
F
.
perspective
(
tensor
,
startpoints
=
spoints
,
endpoints
=
epoints
,
interpolation
=
PIL
.
Image
.
BILINEAR
)
res2
=
F
.
perspective
(
tensor
,
startpoints
=
spoints
,
endpoints
=
epoints
,
interpolation
=
BILINEAR
)
assert_equal
(
res1
,
res2
)
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
@
pytest
.
mark
.
parametrize
(
"dt"
,
[
None
,
torch
.
float32
,
torch
.
float64
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
...
...
@@ -489,9 +514,7 @@ def test_resize(device, dt, size, max_size, interpolation):
assert
resized_tensor
.
size
()[
1
:]
==
resized_pil_img
.
size
[::
-
1
]
if
interpolation
not
in
[
NEAREST
,
]:
if
interpolation
!=
NEAREST
:
# We can not check values if mode = NEAREST, as results are different
# E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]]
# E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
...
...
@@ -504,9 +527,7 @@ def test_resize(device, dt, size, max_size, interpolation):
_assert_approx_equal_tensor_to_pil
(
resized_tensor_f
,
resized_pil_img
,
tol
=
8.0
)
if
isinstance
(
size
,
int
):
script_size
=
[
size
,
]
script_size
=
[
size
]
else
:
script_size
=
size
...
...
@@ -523,6 +544,10 @@ def test_resize_asserts(device):
tensor
,
pil_img
=
_create_data
(
26
,
36
,
device
=
device
)
res1
=
F
.
resize
(
tensor
,
size
=
32
,
interpolation
=
PIL
.
Image
.
BILINEAR
)
res2
=
F
.
resize
(
tensor
,
size
=
32
,
interpolation
=
BILINEAR
)
assert_equal
(
res1
,
res2
)
for
img
in
(
tensor
,
pil_img
):
exp_msg
=
"max_size should only be passed if size specifies the length of the smaller edge"
with
pytest
.
raises
(
ValueError
,
match
=
exp_msg
):
...
...
test/test_onnx.py
View file @
b570f2c1
...
...
@@ -407,13 +407,12 @@ class TestONNXExporter:
def
get_image
(
self
,
rel_path
:
str
,
size
:
Tuple
[
int
,
int
])
->
torch
.
Tensor
:
import
os
import
torchvision.transforms._pil_constants
as
_pil_constants
from
PIL
import
Image
from
torchvision.transforms
import
functional
as
F
data_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"assets"
)
path
=
os
.
path
.
join
(
data_dir
,
*
rel_path
.
split
(
"/"
))
image
=
Image
.
open
(
path
).
convert
(
"RGB"
).
resize
(
size
,
_pil_constants
.
BILINEAR
)
image
=
Image
.
open
(
path
).
convert
(
"RGB"
).
resize
(
size
,
Image
.
BILINEAR
)
return
F
.
convert_image_dtype
(
F
.
pil_to_tensor
(
image
))
...
...
test/test_transforms.py
View file @
b570f2c1
...
...
@@ -9,7 +9,6 @@ import numpy as np
import
pytest
import
torch
import
torchvision.transforms
as
transforms
import
torchvision.transforms._pil_constants
as
_pil_constants
import
torchvision.transforms.functional
as
F
import
torchvision.transforms.functional_tensor
as
F_t
from
PIL
import
Image
...
...
@@ -175,7 +174,7 @@ class TestAccImage:
def
test_accimage_resize
(
self
):
trans
=
transforms
.
Compose
(
[
transforms
.
Resize
(
256
,
interpolation
=
_pil_constants
.
LINEAR
),
transforms
.
Resize
(
256
,
interpolation
=
Image
.
LINEAR
),
transforms
.
PILToTensor
(),
transforms
.
ConvertImageDtype
(
dtype
=
torch
.
float
),
]
...
...
@@ -1533,10 +1532,10 @@ def test_ten_crop(should_vflip, single_dim):
five_crop
.
__repr__
()
if
should_vflip
:
vflipped_img
=
img
.
transpose
(
_pil_constants
.
FLIP_TOP_BOTTOM
)
vflipped_img
=
img
.
transpose
(
Image
.
FLIP_TOP_BOTTOM
)
expected_output
+=
five_crop
(
vflipped_img
)
else
:
hflipped_img
=
img
.
transpose
(
_pil_constants
.
FLIP_LEFT_RIGHT
)
hflipped_img
=
img
.
transpose
(
Image
.
FLIP_LEFT_RIGHT
)
expected_output
+=
five_crop
(
hflipped_img
)
assert
len
(
results
)
==
10
...
...
@@ -1883,6 +1882,9 @@ def test_random_rotation():
# Checking if RandomRotation can be printed as string
t
.
__repr__
()
t
=
transforms
.
RandomRotation
((
-
10
,
10
),
interpolation
=
Image
.
BILINEAR
)
assert
t
.
interpolation
==
transforms
.
InterpolationMode
.
BILINEAR
def
test_random_rotation_error
():
# assert fill being either a Sequence or a Number
...
...
@@ -2212,6 +2214,9 @@ def test_random_affine():
t
=
transforms
.
RandomAffine
(
10
,
interpolation
=
transforms
.
InterpolationMode
.
BILINEAR
)
assert
"bilinear"
in
t
.
__repr__
()
t
=
transforms
.
RandomAffine
(
10
,
interpolation
=
Image
.
BILINEAR
)
assert
t
.
interpolation
==
transforms
.
InterpolationMode
.
BILINEAR
def
test_elastic_transformation
():
with
pytest
.
raises
(
TypeError
,
match
=
r
"alpha should be float or a sequence of floats"
):
...
...
@@ -2228,9 +2233,8 @@ def test_elastic_transformation():
with
pytest
.
raises
(
ValueError
,
match
=
r
"sigma is a sequence its length should be 2"
):
transforms
.
ElasticTransform
(
alpha
=
2.0
,
sigma
=
[
1.0
,
0.0
,
1.0
])
with
pytest
.
warns
(
UserWarning
,
match
=
r
"Argument interpolation should be of type InterpolationMode"
):
t
=
transforms
.
transforms
.
ElasticTransform
(
alpha
=
2.0
,
sigma
=
2.0
,
interpolation
=
2
)
assert
t
.
interpolation
==
transforms
.
InterpolationMode
.
BILINEAR
t
=
transforms
.
transforms
.
ElasticTransform
(
alpha
=
2.0
,
sigma
=
2.0
,
interpolation
=
Image
.
BILINEAR
)
assert
t
.
interpolation
==
transforms
.
InterpolationMode
.
BILINEAR
with
pytest
.
raises
(
TypeError
,
match
=
r
"fill should be int or float"
):
transforms
.
ElasticTransform
(
alpha
=
1.0
,
sigma
=
1.0
,
fill
=
{})
...
...
test/test_transforms_tensor.py
View file @
b570f2c1
...
...
@@ -3,9 +3,9 @@ import sys
import
warnings
import
numpy
as
np
import
PIL.Image
import
pytest
import
torch
import
torchvision.transforms._pil_constants
as
_pil_constants
from
common_utils
import
(
_assert_approx_equal_tensor_to_pil
,
_assert_equal_tensor_to_pil
,
...
...
@@ -657,13 +657,13 @@ def test_autoaugment__op_apply_shear(interpolation, mode):
matrix
=
(
1
,
level
,
0
,
0
,
1
,
0
)
elif
mode
==
"Y"
:
matrix
=
(
1
,
0
,
0
,
level
,
1
,
0
)
return
pil_img
.
transform
((
image_size
,
image_size
),
_pil_constants
.
AFFINE
,
matrix
,
resample
=
resample
)
return
pil_img
.
transform
((
image_size
,
image_size
),
PIL
.
Image
.
AFFINE
,
matrix
,
resample
=
resample
)
t_img
,
pil_img
=
_create_data
(
image_size
,
image_size
)
resample_pil
=
{
F
.
InterpolationMode
.
NEAREST
:
_pil_constants
.
NEAREST
,
F
.
InterpolationMode
.
BILINEAR
:
_pil_constants
.
BILINEAR
,
F
.
InterpolationMode
.
NEAREST
:
PIL
.
Image
.
NEAREST
,
F
.
InterpolationMode
.
BILINEAR
:
PIL
.
Image
.
BILINEAR
,
}[
interpolation
]
level
=
0.3
...
...
torchvision/transforms/_pil_constants.py
deleted
100644 → 0
View file @
70745705
from
PIL
import
Image
# See https://pillow.readthedocs.io/en/stable/releasenotes/9.1.0.html#deprecations
# TODO: Remove this file once PIL minimal version is >= 9.1
if
hasattr
(
Image
,
"Resampling"
):
BICUBIC
=
Image
.
Resampling
.
BICUBIC
BILINEAR
=
Image
.
Resampling
.
BILINEAR
LINEAR
=
Image
.
Resampling
.
BILINEAR
NEAREST
=
Image
.
Resampling
.
NEAREST
AFFINE
=
Image
.
Transform
.
AFFINE
FLIP_LEFT_RIGHT
=
Image
.
Transpose
.
FLIP_LEFT_RIGHT
FLIP_TOP_BOTTOM
=
Image
.
Transpose
.
FLIP_TOP_BOTTOM
PERSPECTIVE
=
Image
.
Transform
.
PERSPECTIVE
else
:
BICUBIC
=
Image
.
BICUBIC
BILINEAR
=
Image
.
BILINEAR
NEAREST
=
Image
.
NEAREST
LINEAR
=
Image
.
LINEAR
AFFINE
=
Image
.
AFFINE
FLIP_LEFT_RIGHT
=
Image
.
FLIP_LEFT_RIGHT
FLIP_TOP_BOTTOM
=
Image
.
FLIP_TOP_BOTTOM
PERSPECTIVE
=
Image
.
PERSPECTIVE
torchvision/transforms/functional.py
View file @
b570f2c1
...
...
@@ -421,6 +421,7 @@ def resize(
Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are
supported.
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
max_size (int, optional): The maximum allowed for the longer edge of
the resized image: if the longer edge of the image is greater
than ``max_size`` after being resized according to ``size``, then
...
...
@@ -454,8 +455,12 @@ def resize(
if
not
torch
.
jit
.
is_scripting
()
and
not
torch
.
jit
.
is_tracing
():
_log_api_usage_once
(
resize
)
if
not
isinstance
(
interpolation
,
InterpolationMode
):
raise
TypeError
(
"Argument interpolation should be a InterpolationMode"
)
if
isinstance
(
interpolation
,
int
):
interpolation
=
_interpolation_modes_from_int
(
interpolation
)
elif
not
isinstance
(
interpolation
,
InterpolationMode
):
raise
TypeError
(
"Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
)
if
isinstance
(
size
,
(
list
,
tuple
)):
if
len
(
size
)
not
in
[
1
,
2
]:
...
...
@@ -630,6 +635,7 @@ def resized_crop(
Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are
supported.
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
antialias (bool, optional): Whether to apply antialiasing.
It only affects **tensors** with bilinear or bicubic modes and it is
ignored otherwise: on PIL images, antialiasing is always applied on
...
...
@@ -726,6 +732,7 @@ def perspective(
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
...
...
@@ -741,8 +748,12 @@ def perspective(
coeffs
=
_get_perspective_coeffs
(
startpoints
,
endpoints
)
if
not
isinstance
(
interpolation
,
InterpolationMode
):
raise
TypeError
(
"Argument interpolation should be a InterpolationMode"
)
if
isinstance
(
interpolation
,
int
):
interpolation
=
_interpolation_modes_from_int
(
interpolation
)
elif
not
isinstance
(
interpolation
,
InterpolationMode
):
raise
TypeError
(
"Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
)
if
not
isinstance
(
img
,
torch
.
Tensor
):
pil_interpolation
=
pil_modes_mapping
[
interpolation
]
...
...
@@ -1076,6 +1087,7 @@ def rotate(
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
expand (bool, optional): Optional expansion flag.
If true, expands the output image to make it large enough to hold the entire rotated image.
If false or omitted, make the output image the same size as the input image.
...
...
@@ -1097,15 +1109,19 @@ def rotate(
if
not
torch
.
jit
.
is_scripting
()
and
not
torch
.
jit
.
is_tracing
():
_log_api_usage_once
(
rotate
)
if
isinstance
(
interpolation
,
int
):
interpolation
=
_interpolation_modes_from_int
(
interpolation
)
elif
not
isinstance
(
interpolation
,
InterpolationMode
):
raise
TypeError
(
"Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
)
if
not
isinstance
(
angle
,
(
int
,
float
)):
raise
TypeError
(
"Argument angle should be int or float"
)
if
center
is
not
None
and
not
isinstance
(
center
,
(
list
,
tuple
)):
raise
TypeError
(
"Argument center should be a sequence"
)
if
not
isinstance
(
interpolation
,
InterpolationMode
):
raise
TypeError
(
"Argument interpolation should be a InterpolationMode"
)
if
not
isinstance
(
img
,
torch
.
Tensor
):
pil_interpolation
=
pil_modes_mapping
[
interpolation
]
return
F_pil
.
rotate
(
img
,
angle
=
angle
,
interpolation
=
pil_interpolation
,
expand
=
expand
,
center
=
center
,
fill
=
fill
)
...
...
@@ -1147,6 +1163,7 @@ def affine(
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
...
...
@@ -1162,6 +1179,13 @@ def affine(
if
not
torch
.
jit
.
is_scripting
()
and
not
torch
.
jit
.
is_tracing
():
_log_api_usage_once
(
affine
)
if
isinstance
(
interpolation
,
int
):
interpolation
=
_interpolation_modes_from_int
(
interpolation
)
elif
not
isinstance
(
interpolation
,
InterpolationMode
):
raise
TypeError
(
"Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
)
if
not
isinstance
(
angle
,
(
int
,
float
)):
raise
TypeError
(
"Argument angle should be int or float"
)
...
...
@@ -1177,9 +1201,6 @@ def affine(
if
not
isinstance
(
shear
,
(
numbers
.
Number
,
(
list
,
tuple
))):
raise
TypeError
(
"Shear should be either a single value or a sequence of two values"
)
if
not
isinstance
(
interpolation
,
InterpolationMode
):
raise
TypeError
(
"Argument interpolation should be a InterpolationMode"
)
if
isinstance
(
angle
,
int
):
angle
=
float
(
angle
)
...
...
@@ -1524,7 +1545,7 @@ def elastic_transform(
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`.
Default is ``InterpolationMode.BILINEAR``.
For backward compatibility integer values (
e.g. ``PIL.Image.NEAR
EST
``
)
are
still acceptable
.
The corresponding Pillow integer constants,
e.g. ``PIL.Image.
BILI
NEAR`` are
accepted as well
.
fill (number or str or tuple): Pixel fill value for constant fill. Default is 0.
If a tuple of length 3, it is used to fill R, G, B channels respectively.
This value is only used when the padding_mode is constant.
...
...
torchvision/transforms/functional_pil.py
View file @
b570f2c1
...
...
@@ -9,7 +9,6 @@ try:
import
accimage
except
ImportError
:
accimage
=
None
from
.
import
_pil_constants
@
torch
.
jit
.
unused
...
...
@@ -54,7 +53,7 @@ def hflip(img: Image.Image) -> Image.Image:
if
not
_is_pil_image
(
img
):
raise
TypeError
(
f
"img should be PIL Image. Got
{
type
(
img
)
}
"
)
return
img
.
transpose
(
_pil_constants
.
FLIP_LEFT_RIGHT
)
return
img
.
transpose
(
Image
.
FLIP_LEFT_RIGHT
)
@
torch
.
jit
.
unused
...
...
@@ -62,7 +61,7 @@ def vflip(img: Image.Image) -> Image.Image:
if
not
_is_pil_image
(
img
):
raise
TypeError
(
f
"img should be PIL Image. Got
{
type
(
img
)
}
"
)
return
img
.
transpose
(
_pil_constants
.
FLIP_TOP_BOTTOM
)
return
img
.
transpose
(
Image
.
FLIP_TOP_BOTTOM
)
@
torch
.
jit
.
unused
...
...
@@ -240,7 +239,7 @@ def crop(
def
resize
(
img
:
Image
.
Image
,
size
:
Union
[
List
[
int
],
int
],
interpolation
:
int
=
_pil_constants
.
BILINEAR
,
interpolation
:
int
=
Image
.
BILINEAR
,
)
->
Image
.
Image
:
if
not
_is_pil_image
(
img
):
...
...
@@ -284,7 +283,7 @@ def _parse_fill(
def
affine
(
img
:
Image
.
Image
,
matrix
:
List
[
float
],
interpolation
:
int
=
_pil_constants
.
NEAREST
,
interpolation
:
int
=
Image
.
NEAREST
,
fill
:
Optional
[
Union
[
int
,
float
,
Sequence
[
int
],
Sequence
[
float
]]]
=
None
,
)
->
Image
.
Image
:
...
...
@@ -293,14 +292,14 @@ def affine(
output_size
=
img
.
size
opts
=
_parse_fill
(
fill
,
img
)
return
img
.
transform
(
output_size
,
_pil_constants
.
AFFINE
,
matrix
,
interpolation
,
**
opts
)
return
img
.
transform
(
output_size
,
Image
.
AFFINE
,
matrix
,
interpolation
,
**
opts
)
@
torch
.
jit
.
unused
def
rotate
(
img
:
Image
.
Image
,
angle
:
float
,
interpolation
:
int
=
_pil_constants
.
NEAREST
,
interpolation
:
int
=
Image
.
NEAREST
,
expand
:
bool
=
False
,
center
:
Optional
[
Tuple
[
int
,
int
]]
=
None
,
fill
:
Optional
[
Union
[
int
,
float
,
Sequence
[
int
],
Sequence
[
float
]]]
=
None
,
...
...
@@ -317,7 +316,7 @@ def rotate(
def
perspective
(
img
:
Image
.
Image
,
perspective_coeffs
:
List
[
float
],
interpolation
:
int
=
_pil_constants
.
BICUBIC
,
interpolation
:
int
=
Image
.
BICUBIC
,
fill
:
Optional
[
Union
[
int
,
float
,
Sequence
[
int
],
Sequence
[
float
]]]
=
None
,
)
->
Image
.
Image
:
...
...
@@ -326,7 +325,7 @@ def perspective(
opts
=
_parse_fill
(
fill
,
img
)
return
img
.
transform
(
img
.
size
,
_pil_constants
.
PERSPECTIVE
,
perspective_coeffs
,
interpolation
,
**
opts
)
return
img
.
transform
(
img
.
size
,
Image
.
PERSPECTIVE
,
perspective_coeffs
,
interpolation
,
**
opts
)
@
torch
.
jit
.
unused
...
...
torchvision/transforms/transforms.py
View file @
b570f2c1
...
...
@@ -298,6 +298,7 @@ class Resize(torch.nn.Module):
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
max_size (int, optional): The maximum allowed for the longer edge of
the resized image: if the longer edge of the image is greater
than ``max_size`` after being resized according to ``size``, then
...
...
@@ -336,6 +337,9 @@ class Resize(torch.nn.Module):
self
.
size
=
size
self
.
max_size
=
max_size
if
isinstance
(
interpolation
,
int
):
interpolation
=
_interpolation_modes_from_int
(
interpolation
)
self
.
interpolation
=
interpolation
self
.
antialias
=
antialias
...
...
@@ -756,6 +760,7 @@ class RandomPerspective(torch.nn.Module):
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
fill (sequence or number): Pixel fill value for the area outside the transformed
image. Default is ``0``. If given a number, the value is used for all bands respectively.
"""
...
...
@@ -765,6 +770,9 @@ class RandomPerspective(torch.nn.Module):
_log_api_usage_once
(
self
)
self
.
p
=
p
if
isinstance
(
interpolation
,
int
):
interpolation
=
_interpolation_modes_from_int
(
interpolation
)
self
.
interpolation
=
interpolation
self
.
distortion_scale
=
distortion_scale
...
...
@@ -861,6 +869,7 @@ class RandomResizedCrop(torch.nn.Module):
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
antialias (bool, optional): Whether to apply antialiasing.
It only affects **tensors** with bilinear or bicubic modes and it is
ignored otherwise: on PIL images, antialiasing is always applied on
...
...
@@ -900,9 +909,11 @@ class RandomResizedCrop(torch.nn.Module):
if
(
scale
[
0
]
>
scale
[
1
])
or
(
ratio
[
0
]
>
ratio
[
1
]):
warnings
.
warn
(
"Scale and ratio should be of kind (min, max)"
)
if
isinstance
(
interpolation
,
int
):
interpolation
=
_interpolation_modes_from_int
(
interpolation
)
self
.
interpolation
=
interpolation
self
.
antialias
=
antialias
self
.
interpolation
=
interpolation
self
.
scale
=
scale
self
.
ratio
=
ratio
...
...
@@ -1139,10 +1150,10 @@ class LinearTransformation(torch.nn.Module):
)
flat_tensor
=
tensor
.
view
(
-
1
,
n
)
-
self
.
mean_vector
transformation_matrix
=
self
.
transformation_matrix
.
to
(
flat_tensor
.
dtype
)
transformed_tensor
=
torch
.
mm
(
flat_tensor
,
transformation_matrix
)
return
transformed_tensor
.
view
(
shape
)
tensor
=
transformed_tensor
.
view
(
shape
)
return
tensor
def
__repr__
(
self
)
->
str
:
s
=
(
...
...
@@ -1293,6 +1304,7 @@ class RandomRotation(torch.nn.Module):
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
expand (bool, optional): Optional expansion flag.
If true, expands the output to make it large enough to hold the entire rotated image.
If false or omitted, make the output image the same size as the input image.
...
...
@@ -1310,6 +1322,9 @@ class RandomRotation(torch.nn.Module):
super
().
__init__
()
_log_api_usage_once
(
self
)
if
isinstance
(
interpolation
,
int
):
interpolation
=
_interpolation_modes_from_int
(
interpolation
)
self
.
degrees
=
_setup_angle
(
degrees
,
name
=
"degrees"
,
req_sizes
=
(
2
,))
if
center
is
not
None
:
...
...
@@ -1393,6 +1408,7 @@ class RandomAffine(torch.nn.Module):
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
fill (sequence or number): Pixel fill value for the area outside the transformed
image. Default is ``0``. If given a number, the value is used for all bands respectively.
center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
...
...
@@ -1415,6 +1431,9 @@ class RandomAffine(torch.nn.Module):
super
().
__init__
()
_log_api_usage_once
(
self
)
if
isinstance
(
interpolation
,
int
):
interpolation
=
_interpolation_modes_from_int
(
interpolation
)
self
.
degrees
=
_setup_angle
(
degrees
,
name
=
"degrees"
,
req_sizes
=
(
2
,))
if
translate
is
not
None
:
...
...
@@ -2039,7 +2058,7 @@ class ElasticTransform(torch.nn.Module):
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
For backward compatibility integer values (
e.g. ``PIL.Image.NEAR
EST
``
)
are
still acceptable
.
The corresponding Pillow integer constants,
e.g. ``PIL.Image.
BILI
NEAR`` are
accepted as well
.
fill (sequence or number): Pixel fill value for the area outside the transformed
image. Default is ``0``. If given a number, the value is used for all bands respectively.
...
...
@@ -2080,12 +2099,7 @@ class ElasticTransform(torch.nn.Module):
self
.
sigma
=
sigma
# Backward compatibility with integer value
if
isinstance
(
interpolation
,
int
):
warnings
.
warn
(
"Argument interpolation should be of type InterpolationMode instead of int. "
"Please, use InterpolationMode enum."
)
interpolation
=
_interpolation_modes_from_int
(
interpolation
)
self
.
interpolation
=
interpolation
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment