OpenDAS / vision / Commits / ca012d39

make PIL kernels private (#7831)

Unverified commit ca012d39, authored Aug 16, 2023 by Philip Meier, committed by GitHub Aug 16, 2023
parent cdbbd666

Changes: 25 · Showing 20 changed files with 329 additions and 355 deletions
docs/source/transforms.rst                          +1  −2
gallery/plot_transforms_v2_e2e.py                   +2  −2
references/detection/presets.py                     +4  −4
references/segmentation/presets.py                  +3  −3
test/common_utils.py                                +5  −5
test/test_prototype_transforms.py                   +2  −2
test/test_transforms_v2.py                          +6  −23
test/test_transforms_v2_consistency.py              +8  −8
test/test_transforms_v2_functional.py               +14 −14
test/test_transforms_v2_refactored.py               +54 −56
test/test_transforms_v2_utils.py                    +2  −2
test/transforms_v2_dispatcher_infos.py              +41 −41
test/transforms_v2_kernel_infos.py                  +45 −46
torchvision/prototype/transforms/_augment.py        +2  −2
torchvision/transforms/v2/__init__.py               +1  −1
torchvision/transforms/v2/_auto_augment.py          +1  −1
torchvision/transforms/v2/_type_conversion.py       +4  −9
torchvision/transforms/v2/functional/__init__.py    +64 −64
torchvision/transforms/v2/functional/_augment.py    +5  −5
torchvision/transforms/v2/functional/_color.py      +65 −65
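Before the per-file diffs, a minimal before/after sketch of the naming scheme this commit applies everywhere (distilled from the hunks below; the image size and arguments are illustrative, not taken from the diff):

import PIL.Image
import torch
from torchvision.transforms.v2 import functional as F

pil_img = PIL.Image.new("RGB", (32, 32))

# Conversion functionals are renamed.
img = F.to_image(pil_img)    # was F.to_image_tensor(pil_img)
back = F.to_pil_image(img)   # was F.to_image_pil(img)

# Tensor kernels drop the `_tensor` suffix but stay public ...
out = F.resize_image(img, size=[16, 16], antialias=True)   # was F.resize_image_tensor
# ... while PIL kernels gain a leading underscore and become private:
out_pil = F._resize_image_pil(pil_img, size=[16, 16])      # was F.resize_image_pil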
docs/source/transforms.rst
@@ -228,12 +228,11 @@ Conversion
     ToPILImage
     v2.ToPILImage
-    v2.ToImagePIL
     ToTensor
     v2.ToTensor
     PILToTensor
     v2.PILToTensor
-    v2.ToImageTensor
+    v2.ToImage
     ConvertImageDtype
     v2.ConvertImageDtype
     v2.ToDtype
gallery/plot_transforms_v2_e2e.py
@@ -27,7 +27,7 @@ def show(sample):
     image, target = sample
     if isinstance(image, PIL.Image.Image):
-        image = F.to_image_tensor(image)
+        image = F.to_image(image)
     image = F.to_dtype(image, torch.uint8, scale=True)

     annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3)
@@ -101,7 +101,7 @@ transform = transforms.Compose(
         transforms.RandomZoomOut(fill={PIL.Image.Image: (123, 117, 104), "others": 0}),
         transforms.RandomIoUCrop(),
         transforms.RandomHorizontalFlip(),
-        transforms.ToImageTensor(),
+        transforms.ToImage(),
         transforms.ConvertImageDtype(torch.float32),
         transforms.SanitizeBoundingBoxes(),
     ]
references/detection/presets.py
@@ -33,7 +33,7 @@ class DetectionPresetTrain:
             transforms = []
             backend = backend.lower()
             if backend == "datapoint":
-                transforms.append(T.ToImageTensor())
+                transforms.append(T.ToImage())
             elif backend == "tensor":
                 transforms.append(T.PILToTensor())
             elif backend != "pil":
@@ -71,7 +71,7 @@ class DetectionPresetTrain:
         if backend == "pil":
             # Note: we could just convert to pure tensors even in v2.
-            transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]
+            transforms += [T.ToImage() if use_v2 else T.PILToTensor()]

         transforms += [T.ConvertImageDtype(torch.float)]
@@ -94,11 +94,11 @@ class DetectionPresetEval:
         backend = backend.lower()
         if backend == "pil":
             # Note: we could just convert to pure tensors even in v2?
-            transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]
+            transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
         elif backend == "tensor":
             transforms += [T.PILToTensor()]
         elif backend == "datapoint":
-            transforms += [T.ToImageTensor()]
+            transforms += [T.ToImage()]
         else:
             raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
references/segmentation/presets.py
@@ -32,7 +32,7 @@ class SegmentationPresetTrain:
             transforms = []
             backend = backend.lower()
             if backend == "datapoint":
-                transforms.append(T.ToImageTensor())
+                transforms.append(T.ToImage())
             elif backend == "tensor":
                 transforms.append(T.PILToTensor())
             elif backend != "pil":
@@ -81,7 +81,7 @@ class SegmentationPresetEval:
             if backend == "tensor":
                 transforms += [T.PILToTensor()]
             elif backend == "datapoint":
-                transforms += [T.ToImageTensor()]
+                transforms += [T.ToImage()]
             elif backend != "pil":
                 raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
@@ -92,7 +92,7 @@ class SegmentationPresetEval:
         if backend == "pil":
             # Note: we could just convert to pure tensors even in v2?
-            transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]
+            transforms += [T.ToImage() if use_v2 else T.PILToTensor()]

         transforms += [
             T.ConvertImageDtype(torch.float),
test/common_utils.py
@@ -27,7 +27,7 @@ from PIL import Image
 from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
 from torchvision import datapoints, io
 from torchvision.transforms._functional_tensor import _max_value as get_max_value
-from torchvision.transforms.v2.functional import to_dtype_image_tensor, to_image_pil, to_image_tensor
+from torchvision.transforms.v2.functional import to_dtype_image, to_image, to_pil_image

 IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
@@ -293,7 +293,7 @@ class ImagePair(TensorLikePair):
         **other_parameters,
     ):
         if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
-            actual, expected = [to_image_tensor(input) for input in [actual, expected]]
+            actual, expected = [to_image(input) for input in [actual, expected]]

         super().__init__(actual, expected, **other_parameters)
         self.mae = mae
@@ -536,7 +536,7 @@ def make_image_tensor(*args, **kwargs):
 def make_image_pil(*args, **kwargs):
-    return to_image_pil(make_image(*args, **kwargs))
+    return to_pil_image(make_image(*args, **kwargs))

 def make_image_loader(
@@ -609,12 +609,12 @@ def make_image_loader_for_interpolation(
         )
     )
-    image_tensor = to_image_tensor(image_pil)
+    image_tensor = to_image(image_pil)
     if memory_format == torch.contiguous_format:
         image_tensor = image_tensor.to(device=device, memory_format=memory_format, copy=True)
     else:
         image_tensor = image_tensor.to(device=device)
-    image_tensor = to_dtype_image_tensor(image_tensor, dtype=dtype, scale=True)
+    image_tensor = to_dtype_image(image_tensor, dtype=dtype, scale=True)

     return datapoints.Image(image_tensor)
test/test_prototype_transforms.py
@@ -17,7 +17,7 @@ from prototype_common_utils import make_label
 from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
 from torchvision.prototype import datapoints, transforms
-from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_image_pil
+from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_pil_image
 from torchvision.transforms.v2.utils import check_type, is_simple_tensor

 BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
@@ -387,7 +387,7 @@ def test_fixed_sized_crop_against_detection_reference():
     size = (600, 800)
     num_objects = 22
-    pil_image = to_image_pil(make_image(size=size, color_space="RGB"))
+    pil_image = to_pil_image(make_image(size=size, color_space="RGB"))
     target = {
         "boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
         "labels": make_label(extra_dims=(num_objects,), categories=80),
test/test_transforms_v2.py
@@ -666,19 +666,19 @@ class TestTransform:
             t(inpt)


-class TestToImageTensor:
+class TestToImage:
     @pytest.mark.parametrize(
         "inpt_type",
         [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
     )
     def test__transform(self, inpt_type, mocker):
         fn = mocker.patch(
-            "torchvision.transforms.v2.functional.to_image_tensor",
+            "torchvision.transforms.v2.functional.to_image",
             return_value=torch.rand(1, 3, 8, 8),
         )

         inpt = mocker.MagicMock(spec=inpt_type)
-        transform = transforms.ToImageTensor()
+        transform = transforms.ToImage()
         transform(inpt)
         if inpt_type in (datapoints.BoundingBoxes, datapoints.Image, str, int):
             assert fn.call_count == 0
@@ -686,30 +686,13 @@ class TestToImageTensor:
             fn.assert_called_once_with(inpt)


-class TestToImagePIL:
-    @pytest.mark.parametrize(
-        "inpt_type",
-        [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
-    )
-    def test__transform(self, inpt_type, mocker):
-        fn = mocker.patch("torchvision.transforms.v2.functional.to_image_pil")
-
-        inpt = mocker.MagicMock(spec=inpt_type)
-        transform = transforms.ToImagePIL()
-        transform(inpt)
-        if inpt_type in (datapoints.BoundingBoxes, PIL.Image.Image, str, int):
-            assert fn.call_count == 0
-        else:
-            fn.assert_called_once_with(inpt, mode=transform.mode)
-
-
 class TestToPILImage:
     @pytest.mark.parametrize(
         "inpt_type",
         [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
     )
     def test__transform(self, inpt_type, mocker):
-        fn = mocker.patch("torchvision.transforms.v2.functional.to_image_pil")
+        fn = mocker.patch("torchvision.transforms.v2.functional.to_pil_image")

         inpt = mocker.MagicMock(spec=inpt_type)
         transform = transforms.ToPILImage()
@@ -1013,7 +996,7 @@ def test_antialias_warning():
 @pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
 @pytest.mark.parametrize("label_type", (torch.Tensor, int))
 @pytest.mark.parametrize("dataset_return_type", (dict, tuple))
-@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor))
+@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImage))
 def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):

     image = datapoints.Image(torch.randint(0, 256, size=(1, 3, 250, 250), dtype=torch.uint8))
@@ -1074,7 +1057,7 @@ def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
 @pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
 @pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite"))
-@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor))
+@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImage))
 @pytest.mark.parametrize("sanitize", (True, False))
 def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
     torch.manual_seed(0)
test/test_transforms_v2_consistency.py
@@ -30,7 +30,7 @@ from torchvision._utils import sequence_to_str
 from torchvision.transforms import functional as legacy_F
 from torchvision.transforms.v2 import functional as prototype_F
 from torchvision.transforms.v2._utils import _get_fill
-from torchvision.transforms.v2.functional import to_image_pil
+from torchvision.transforms.v2.functional import to_pil_image
 from torchvision.transforms.v2.utils import query_size

 DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)])
@@ -630,7 +630,7 @@ def check_call_consistency(
         )

         if image.ndim == 3 and supports_pil:
-            image_pil = to_image_pil(image)
+            image_pil = to_pil_image(image)

             try:
                 torch.manual_seed(0)
@@ -869,7 +869,7 @@ class TestToTensorTransforms:
         legacy_transform = legacy_transforms.PILToTensor()

         for image in make_images(extra_dims=[()]):
-            image_pil = to_image_pil(image)
+            image_pil = to_pil_image(image)

             assert_equal(prototype_transform(image_pil), legacy_transform(image_pil))
@@ -879,7 +879,7 @@ class TestToTensorTransforms:
         legacy_transform = legacy_transforms.ToTensor()

         for image in make_images(extra_dims=[()]):
-            image_pil = to_image_pil(image)
+            image_pil = to_pil_image(image)
             image_numpy = np.array(image_pil)

             assert_equal(prototype_transform(image_pil), legacy_transform(image_pil))
@@ -1088,7 +1088,7 @@ class TestRefDetTransforms:
         def make_label(extra_dims, categories):
             return torch.randint(categories, extra_dims, dtype=torch.int64)

-        pil_image = to_image_pil(make_image(size=size, color_space="RGB"))
+        pil_image = to_pil_image(make_image(size=size, color_space="RGB"))
         target = {
             "boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
             "labels": make_label(extra_dims=(num_objects,), categories=80),
@@ -1192,7 +1192,7 @@ class TestRefSegTransforms:
             conv_fns = []
             if supports_pil:
-                conv_fns.append(to_image_pil)
+                conv_fns.append(to_pil_image)
             conv_fns.extend([torch.Tensor, lambda x: x])

             for conv_fn in conv_fns:
@@ -1201,8 +1201,8 @@ class TestRefSegTransforms:
                 dp = (conv_fn(datapoint_image), datapoint_mask)
                 dp_ref = (
-                    to_image_pil(datapoint_image) if supports_pil else datapoint_image.as_subclass(torch.Tensor),
-                    to_image_pil(datapoint_mask),
+                    to_pil_image(datapoint_image) if supports_pil else datapoint_image.as_subclass(torch.Tensor),
+                    to_pil_image(datapoint_mask),
                 )

                 yield dp, dp_ref
test/test_transforms_v2_functional.py
@@ -280,12 +280,12 @@ class TestKernels:
         adapted_other_args, adapted_kwargs = info.float32_vs_uint8(other_args, kwargs)

         actual = info.kernel(
-            F.to_dtype_image_tensor(input, dtype=torch.float32, scale=True),
+            F.to_dtype_image(input, dtype=torch.float32, scale=True),
             *adapted_other_args,
             **adapted_kwargs,
         )

-        expected = F.to_dtype_image_tensor(info.kernel(input, *other_args, **kwargs), dtype=torch.float32, scale=True)
+        expected = F.to_dtype_image(info.kernel(input, *other_args, **kwargs), dtype=torch.float32, scale=True)

         assert_close(
             actual,
@@ -377,7 +377,7 @@ class TestDispatchers:
         if image_datapoint.ndim > 3:
             pytest.skip("Input is batched")

-        image_pil = F.to_image_pil(image_datapoint)
+        image_pil = F.to_pil_image(image_datapoint)

         output = info.dispatcher(image_pil, *other_args, **kwargs)
@@ -470,7 +470,7 @@ class TestDispatchers:
             (F.hflip, F.horizontal_flip),
             (F.vflip, F.vertical_flip),
             (F.get_image_num_channels, F.get_num_channels),
-            (F.to_pil_image, F.to_image_pil),
+            (F.to_pil_image, F.to_pil_image),
             (F.elastic_transform, F.elastic),
             (F.to_grayscale, F.rgb_to_grayscale),
         ]
@@ -493,7 +493,7 @@ def test_normalize_image_tensor_stats(device, num_channels):
     mean = image.mean(dim=(1, 2)).tolist()
     std = image.std(dim=(1, 2)).tolist()
-    assert_samples_from_standard_normal(F.normalize_image_tensor(image, mean, std))
+    assert_samples_from_standard_normal(F.normalize_image(image, mean, std))

 class TestClampBoundingBoxes:
@@ -899,7 +899,7 @@ def test_correctness_center_crop_mask(device, output_size):
     _, image_height, image_width = mask.shape
     if crop_width > image_height or crop_height > image_width:
         padding = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
-        mask = F.pad_image_tensor(mask, padding, fill=0)
+        mask = F.pad_image(mask, padding, fill=0)

     left = round((image_width - crop_width) * 0.5)
     top = round((image_height - crop_height) * 0.5)
@@ -920,7 +920,7 @@ def test_correctness_center_crop_mask(device, output_size):
 @pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)])
 @pytest.mark.parametrize("sigma", [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)])
 def test_correctness_gaussian_blur_image_tensor(device, canvas_size, dt, ksize, sigma):
-    fn = F.gaussian_blur_image_tensor
+    fn = F.gaussian_blur_image

     # true_cv2_results = {
     #     # np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
@@ -977,8 +977,8 @@ def test_correctness_gaussian_blur_image_tensor(device, canvas_size, dt, ksize,
         PIL.Image.new("RGB", (32, 32), 122),
     ],
 )
-def test_to_image_tensor(inpt):
-    output = F.to_image_tensor(inpt)
+def test_to_image(inpt):
+    output = F.to_image(inpt)
     assert isinstance(output, torch.Tensor)
     assert output.shape == (3, 32, 32)
@@ -993,8 +993,8 @@ def test_to_image_tensor(inpt):
     ],
 )
 @pytest.mark.parametrize("mode", [None, "RGB"])
-def test_to_image_pil(inpt, mode):
-    output = F.to_image_pil(inpt, mode=mode)
+def test_to_pil_image(inpt, mode):
+    output = F.to_pil_image(inpt, mode=mode)
     assert isinstance(output, PIL.Image.Image)

     assert np.asarray(inpt).sum() == np.asarray(output).sum()
@@ -1002,12 +1002,12 @@ def test_to_image_pil(inpt, mode):
 def test_equalize_image_tensor_edge_cases():
     inpt = torch.zeros(3, 200, 200, dtype=torch.uint8)
-    output = F.equalize_image_tensor(inpt)
+    output = F.equalize_image(inpt)
     torch.testing.assert_close(inpt, output)

     inpt = torch.zeros(5, 3, 200, 200, dtype=torch.uint8)
     inpt[..., 100:, 100:] = 1
-    output = F.equalize_image_tensor(inpt)
+    output = F.equalize_image(inpt)
     assert output.unique().tolist() == [0, 255]
@@ -1024,7 +1024,7 @@ def test_correctness_uniform_temporal_subsample(device):
 # TODO: We can remove this test and related torchvision workaround
 # once we fixed related pytorch issue: https://github.com/pytorch/pytorch/issues/68430
 @make_info_args_kwargs_parametrization(
-    [info for info in KERNEL_INFOS if info.kernel is F.resize_image_tensor],
+    [info for info in KERNEL_INFOS if info.kernel is F.resize_image],
     args_kwargs_fn=lambda info: info.reference_inputs_fn(),
 )
 def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwargs):
test/test_transforms_v2_refactored.py
@@ -437,7 +437,7 @@ class TestResize:
             check_cuda_vs_cpu_tolerances = dict(rtol=0, atol=atol / 255 if dtype.is_floating_point else atol)

         check_kernel(
-            F.resize_image_tensor,
+            F.resize_image,
             make_image(self.INPUT_SIZE, dtype=dtype, device=device),
             size=size,
             interpolation=interpolation,
@@ -495,9 +495,9 @@ class TestResize:
     @pytest.mark.parametrize(
         ("kernel", "input_type"),
         [
-            (F.resize_image_tensor, torch.Tensor),
-            (F.resize_image_pil, PIL.Image.Image),
-            (F.resize_image_tensor, datapoints.Image),
+            (F.resize_image, torch.Tensor),
+            (F._resize_image_pil, PIL.Image.Image),
+            (F.resize_image, datapoints.Image),
             (F.resize_bounding_boxes, datapoints.BoundingBoxes),
             (F.resize_mask, datapoints.Mask),
             (F.resize_video, datapoints.Video),
@@ -541,9 +541,7 @@ class TestResize:
         image = make_image(self.INPUT_SIZE, dtype=torch.uint8)

         actual = fn(image, size=size, interpolation=interpolation, **max_size_kwarg, antialias=True)
-        expected = F.to_image_tensor(
-            F.resize(F.to_image_pil(image), size=size, interpolation=interpolation, **max_size_kwarg)
-        )
+        expected = F.to_image(F.resize(F.to_pil_image(image), size=size, interpolation=interpolation, **max_size_kwarg))

         self._check_output_size(image, actual, size=size, **max_size_kwarg)
         torch.testing.assert_close(actual, expected, atol=1, rtol=0)
@@ -739,7 +737,7 @@ class TestHorizontalFlip:
     @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     def test_kernel_image_tensor(self, dtype, device):
-        check_kernel(F.horizontal_flip_image_tensor, make_image(dtype=dtype, device=device))
+        check_kernel(F.horizontal_flip_image, make_image(dtype=dtype, device=device))

     @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
     @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@@ -770,9 +768,9 @@ class TestHorizontalFlip:
     @pytest.mark.parametrize(
         ("kernel", "input_type"),
         [
-            (F.horizontal_flip_image_tensor, torch.Tensor),
-            (F.horizontal_flip_image_pil, PIL.Image.Image),
-            (F.horizontal_flip_image_tensor, datapoints.Image),
+            (F.horizontal_flip_image, torch.Tensor),
+            (F._horizontal_flip_image_pil, PIL.Image.Image),
+            (F.horizontal_flip_image, datapoints.Image),
             (F.horizontal_flip_bounding_boxes, datapoints.BoundingBoxes),
             (F.horizontal_flip_mask, datapoints.Mask),
             (F.horizontal_flip_video, datapoints.Video),
@@ -796,7 +794,7 @@ class TestHorizontalFlip:
         image = make_image(dtype=torch.uint8, device="cpu")

         actual = fn(image)
-        expected = F.to_image_tensor(F.horizontal_flip(F.to_image_pil(image)))
+        expected = F.to_image(F.horizontal_flip(F.to_pil_image(image)))

         torch.testing.assert_close(actual, expected)
@@ -900,7 +898,7 @@ class TestAffine:
         if param == "fill":
             value = adapt_fill(value, dtype=dtype)
         self._check_kernel(
-            F.affine_image_tensor,
+            F.affine_image,
             make_image(dtype=dtype, device=device),
             **{param: value},
             check_scripted_vs_eager=not (param in {"shear", "fill"} and isinstance(value, (int, float))),
@@ -946,9 +944,9 @@ class TestAffine:
     @pytest.mark.parametrize(
         ("kernel", "input_type"),
         [
-            (F.affine_image_tensor, torch.Tensor),
-            (F.affine_image_pil, PIL.Image.Image),
-            (F.affine_image_tensor, datapoints.Image),
+            (F.affine_image, torch.Tensor),
+            (F._affine_image_pil, PIL.Image.Image),
+            (F.affine_image, datapoints.Image),
             (F.affine_bounding_boxes, datapoints.BoundingBoxes),
             (F.affine_mask, datapoints.Mask),
             (F.affine_video, datapoints.Video),
@@ -991,9 +989,9 @@ class TestAffine:
             interpolation=interpolation,
             fill=fill,
         )
-        expected = F.to_image_tensor(
+        expected = F.to_image(
             F.affine(
-                F.to_image_pil(image),
+                F.to_pil_image(image),
                 angle=angle,
                 translate=translate,
                 scale=scale,
@@ -1026,7 +1024,7 @@ class TestAffine:
         actual = transform(image)

         torch.manual_seed(seed)
-        expected = F.to_image_tensor(transform(F.to_image_pil(image)))
+        expected = F.to_image(transform(F.to_pil_image(image)))

         mae = (actual.float() - expected.float()).abs().mean()
         assert mae < 2 if interpolation is transforms.InterpolationMode.NEAREST else 8
@@ -1204,7 +1202,7 @@ class TestVerticalFlip:
     @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     def test_kernel_image_tensor(self, dtype, device):
-        check_kernel(F.vertical_flip_image_tensor, make_image(dtype=dtype, device=device))
+        check_kernel(F.vertical_flip_image, make_image(dtype=dtype, device=device))

     @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
     @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@@ -1235,9 +1233,9 @@ class TestVerticalFlip:
     @pytest.mark.parametrize(
         ("kernel", "input_type"),
         [
-            (F.vertical_flip_image_tensor, torch.Tensor),
-            (F.vertical_flip_image_pil, PIL.Image.Image),
-            (F.vertical_flip_image_tensor, datapoints.Image),
+            (F.vertical_flip_image, torch.Tensor),
+            (F._vertical_flip_image_pil, PIL.Image.Image),
+            (F.vertical_flip_image, datapoints.Image),
             (F.vertical_flip_bounding_boxes, datapoints.BoundingBoxes),
             (F.vertical_flip_mask, datapoints.Mask),
             (F.vertical_flip_video, datapoints.Video),
@@ -1259,7 +1257,7 @@ class TestVerticalFlip:
         image = make_image(dtype=torch.uint8, device="cpu")

         actual = fn(image)
-        expected = F.to_image_tensor(F.vertical_flip(F.to_image_pil(image)))
+        expected = F.to_image(F.vertical_flip(F.to_pil_image(image)))

         torch.testing.assert_close(actual, expected)
@@ -1339,7 +1337,7 @@ class TestRotate:
         if param != "angle":
             kwargs["angle"] = self._MINIMAL_AFFINE_KWARGS["angle"]
         check_kernel(
-            F.rotate_image_tensor,
+            F.rotate_image,
             make_image(dtype=dtype, device=device),
             **kwargs,
             check_scripted_vs_eager=not (param == "fill" and isinstance(value, (int, float))),
@@ -1385,9 +1383,9 @@ class TestRotate:
     @pytest.mark.parametrize(
         ("kernel", "input_type"),
         [
-            (F.rotate_image_tensor, torch.Tensor),
-            (F.rotate_image_pil, PIL.Image.Image),
-            (F.rotate_image_tensor, datapoints.Image),
+            (F.rotate_image, torch.Tensor),
+            (F._rotate_image_pil, PIL.Image.Image),
+            (F.rotate_image, datapoints.Image),
             (F.rotate_bounding_boxes, datapoints.BoundingBoxes),
             (F.rotate_mask, datapoints.Mask),
             (F.rotate_video, datapoints.Video),
@@ -1419,9 +1417,9 @@ class TestRotate:
         fill = adapt_fill(fill, dtype=torch.uint8)
         actual = F.rotate(image, angle=angle, center=center, interpolation=interpolation, expand=expand, fill=fill)
-        expected = F.to_image_tensor(
+        expected = F.to_image(
             F.rotate(
-                F.to_image_pil(image), angle=angle, center=center, interpolation=interpolation, expand=expand, fill=fill
+                F.to_pil_image(image), angle=angle, center=center, interpolation=interpolation, expand=expand, fill=fill
             )
         )
@@ -1452,7 +1450,7 @@ class TestRotate:
         actual = transform(image)

         torch.manual_seed(seed)
-        expected = F.to_image_tensor(transform(F.to_image_pil(image)))
+        expected = F.to_image(transform(F.to_pil_image(image)))

         mae = (actual.float() - expected.float()).abs().mean()
         assert mae < 1 if interpolation is transforms.InterpolationMode.NEAREST else 6
@@ -1621,8 +1619,8 @@ class TestToDtype:
     @pytest.mark.parametrize(
         ("kernel", "make_input"),
         [
-            (F.to_dtype_image_tensor, make_image_tensor),
-            (F.to_dtype_image_tensor, make_image),
+            (F.to_dtype_image, make_image_tensor),
+            (F.to_dtype_image, make_image),
             (F.to_dtype_video, make_video),
         ],
     )
@@ -1801,7 +1799,7 @@ class TestAdjustBrightness:
     @pytest.mark.parametrize(
         ("kernel", "make_input"),
         [
-            (F.adjust_brightness_image_tensor, make_image),
+            (F.adjust_brightness_image, make_image),
             (F.adjust_brightness_video, make_video),
         ],
     )
@@ -1817,9 +1815,9 @@ class TestAdjustBrightness:
     @pytest.mark.parametrize(
         ("kernel", "input_type"),
         [
-            (F.adjust_brightness_image_tensor, torch.Tensor),
-            (F.adjust_brightness_image_pil, PIL.Image.Image),
-            (F.adjust_brightness_image_tensor, datapoints.Image),
+            (F.adjust_brightness_image, torch.Tensor),
+            (F._adjust_brightness_image_pil, PIL.Image.Image),
+            (F.adjust_brightness_image, datapoints.Image),
             (F.adjust_brightness_video, datapoints.Video),
         ],
     )
@@ -1831,7 +1829,7 @@ class TestAdjustBrightness:
         image = make_image(dtype=torch.uint8, device="cpu")

         actual = F.adjust_brightness(image, brightness_factor=brightness_factor)
-        expected = F.to_image_tensor(F.adjust_brightness(F.to_image_pil(image), brightness_factor=brightness_factor))
+        expected = F.to_image(F.adjust_brightness(F.to_pil_image(image), brightness_factor=brightness_factor))

         torch.testing.assert_close(actual, expected)
@@ -1979,9 +1977,9 @@ class TestShapeGetters:
     @pytest.mark.parametrize(
         ("kernel", "make_input"),
         [
-            (F.get_dimensions_image_tensor, make_image_tensor),
-            (F.get_dimensions_image_pil, make_image_pil),
-            (F.get_dimensions_image_tensor, make_image),
+            (F.get_dimensions_image, make_image_tensor),
+            (F._get_dimensions_image_pil, make_image_pil),
+            (F.get_dimensions_image, make_image),
             (F.get_dimensions_video, make_video),
         ],
     )
@@ -1996,9 +1994,9 @@ class TestShapeGetters:
     @pytest.mark.parametrize(
         ("kernel", "make_input"),
         [
-            (F.get_num_channels_image_tensor, make_image_tensor),
-            (F.get_num_channels_image_pil, make_image_pil),
-            (F.get_num_channels_image_tensor, make_image),
+            (F.get_num_channels_image, make_image_tensor),
+            (F._get_num_channels_image_pil, make_image_pil),
+            (F.get_num_channels_image, make_image),
             (F.get_num_channels_video, make_video),
         ],
     )
@@ -2012,9 +2010,9 @@ class TestShapeGetters:
     @pytest.mark.parametrize(
         ("kernel", "make_input"),
         [
-            (F.get_size_image_tensor, make_image_tensor),
-            (F.get_size_image_pil, make_image_pil),
-            (F.get_size_image_tensor, make_image),
+            (F.get_size_image, make_image_tensor),
+            (F._get_size_image_pil, make_image_pil),
+            (F.get_size_image, make_image),
             (F.get_size_bounding_boxes, make_bounding_box),
             (F.get_size_mask, make_detection_mask),
             (F.get_size_mask, make_segmentation_mask),
@@ -2101,7 +2099,7 @@ class TestRegisterKernel:
             F.register_kernel(F.resize, object)

         with pytest.raises(ValueError, match="cannot be registered for the builtin datapoint classes"):
-            F.register_kernel(F.resize, datapoints.Image)(F.resize_image_tensor)
+            F.register_kernel(F.resize, datapoints.Image)(F.resize_image)

         class CustomDatapoint(datapoints.Datapoint):
             pass
@@ -2119,9 +2117,9 @@ class TestGetKernel:
     # We are using F.resize as functional and the kernels below as proxy. Any other functional / kernels combination
    # would also be fine
     KERNELS = {
-        torch.Tensor: F.resize_image_tensor,
-        PIL.Image.Image: F.resize_image_pil,
-        datapoints.Image: F.resize_image_tensor,
+        torch.Tensor: F.resize_image,
+        PIL.Image.Image: F._resize_image_pil,
+        datapoints.Image: F.resize_image,
         datapoints.BoundingBoxes: F.resize_bounding_boxes,
         datapoints.Mask: F.resize_mask,
         datapoints.Video: F.resize_video,
@@ -2217,10 +2215,10 @@ class TestPermuteChannels:
     @pytest.mark.parametrize(
         ("kernel", "make_input"),
         [
-            (F.permute_channels_image_tensor, make_image_tensor),
+            (F.permute_channels_image, make_image_tensor),
             # FIXME
             # check_kernel does not support PIL kernel, but it should
-            (F.permute_channels_image_tensor, make_image),
+            (F.permute_channels_image, make_image),
             (F.permute_channels_video, make_video),
         ],
     )
@@ -2236,9 +2234,9 @@ class TestPermuteChannels:
     @pytest.mark.parametrize(
         ("kernel", "input_type"),
         [
-            (F.permute_channels_image_tensor, torch.Tensor),
-            (F.permute_channels_image_pil, PIL.Image.Image),
-            (F.permute_channels_image_tensor, datapoints.Image),
+            (F.permute_channels_image, torch.Tensor),
+            (F._permute_channels_image_pil, PIL.Image.Image),
+            (F.permute_channels_image, datapoints.Image),
             (F.permute_channels_video, datapoints.Video),
         ],
     )
test/test_transforms_v2_utils.py
@@ -7,7 +7,7 @@ import torchvision.transforms.v2.utils
 from common_utils import DEFAULT_SIZE, make_bounding_box, make_detection_mask, make_image

 from torchvision import datapoints
-from torchvision.transforms.v2.functional import to_image_pil
+from torchvision.transforms.v2.functional import to_pil_image
 from torchvision.transforms.v2.utils import has_all, has_any
@@ -44,7 +44,7 @@ MASK = make_detection_mask(DEFAULT_SIZE)
             True,
         ),
         (
-            (to_image_pil(IMAGE),),
+            (to_pil_image(IMAGE),),
             (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor),
             True,
         ),
test/transforms_v2_dispatcher_infos.py
@@ -142,32 +142,32 @@ DISPATCHER_INFOS = [
     DispatcherInfo(
         F.crop,
         kernels={
-            datapoints.Image: F.crop_image_tensor,
+            datapoints.Image: F.crop_image,
             datapoints.Video: F.crop_video,
             datapoints.BoundingBoxes: F.crop_bounding_boxes,
             datapoints.Mask: F.crop_mask,
         },
-        pil_kernel_info=PILKernelInfo(F.crop_image_pil, kernel_name="crop_image_pil"),
+        pil_kernel_info=PILKernelInfo(F._crop_image_pil, kernel_name="crop_image_pil"),
     ),
     DispatcherInfo(
         F.resized_crop,
         kernels={
-            datapoints.Image: F.resized_crop_image_tensor,
+            datapoints.Image: F.resized_crop_image,
             datapoints.Video: F.resized_crop_video,
             datapoints.BoundingBoxes: F.resized_crop_bounding_boxes,
             datapoints.Mask: F.resized_crop_mask,
         },
-        pil_kernel_info=PILKernelInfo(F.resized_crop_image_pil),
+        pil_kernel_info=PILKernelInfo(F._resized_crop_image_pil),
     ),
     DispatcherInfo(
         F.pad,
         kernels={
-            datapoints.Image: F.pad_image_tensor,
+            datapoints.Image: F.pad_image,
             datapoints.Video: F.pad_video,
             datapoints.BoundingBoxes: F.pad_bounding_boxes,
             datapoints.Mask: F.pad_mask,
         },
-        pil_kernel_info=PILKernelInfo(F.pad_image_pil, kernel_name="pad_image_pil"),
+        pil_kernel_info=PILKernelInfo(F._pad_image_pil, kernel_name="pad_image_pil"),
         test_marks=[
             *xfails_pil(
                 reason=(
@@ -184,12 +184,12 @@ DISPATCHER_INFOS = [
     DispatcherInfo(
         F.perspective,
         kernels={
-            datapoints.Image: F.perspective_image_tensor,
+            datapoints.Image: F.perspective_image,
             datapoints.Video: F.perspective_video,
             datapoints.BoundingBoxes: F.perspective_bounding_boxes,
             datapoints.Mask: F.perspective_mask,
         },
-        pil_kernel_info=PILKernelInfo(F.perspective_image_pil),
+        pil_kernel_info=PILKernelInfo(F._perspective_image_pil),
         test_marks=[
             *xfails_pil_if_fill_sequence_needs_broadcast,
             xfail_jit_python_scalar_arg("fill"),
@@ -198,23 +198,23 @@ DISPATCHER_INFOS = [
     DispatcherInfo(
         F.elastic,
         kernels={
-            datapoints.Image: F.elastic_image_tensor,
+            datapoints.Image: F.elastic_image,
             datapoints.Video: F.elastic_video,
             datapoints.BoundingBoxes: F.elastic_bounding_boxes,
             datapoints.Mask: F.elastic_mask,
         },
-        pil_kernel_info=PILKernelInfo(F.elastic_image_pil),
+        pil_kernel_info=PILKernelInfo(F._elastic_image_pil),
         test_marks=[xfail_jit_python_scalar_arg("fill")],
     ),
     DispatcherInfo(
         F.center_crop,
         kernels={
-            datapoints.Image: F.center_crop_image_tensor,
+            datapoints.Image: F.center_crop_image,
             datapoints.Video: F.center_crop_video,
             datapoints.BoundingBoxes: F.center_crop_bounding_boxes,
             datapoints.Mask: F.center_crop_mask,
         },
-        pil_kernel_info=PILKernelInfo(F.center_crop_image_pil),
+        pil_kernel_info=PILKernelInfo(F._center_crop_image_pil),
         test_marks=[
             xfail_jit_python_scalar_arg("output_size"),
         ],
@@ -222,10 +222,10 @@ DISPATCHER_INFOS = [
     DispatcherInfo(
         F.gaussian_blur,
         kernels={
-            datapoints.Image: F.gaussian_blur_image_tensor,
+            datapoints.Image: F.gaussian_blur_image,
             datapoints.Video: F.gaussian_blur_video,
         },
-        pil_kernel_info=PILKernelInfo(F.gaussian_blur_image_pil),
+        pil_kernel_info=PILKernelInfo(F._gaussian_blur_image_pil),
         test_marks=[
             xfail_jit_python_scalar_arg("kernel_size"),
             xfail_jit_python_scalar_arg("sigma"),
@@ -234,58 +234,58 @@ DISPATCHER_INFOS = [
     DispatcherInfo(
         F.equalize,
         kernels={
-            datapoints.Image: F.equalize_image_tensor,
+            datapoints.Image: F.equalize_image,
             datapoints.Video: F.equalize_video,
         },
-        pil_kernel_info=PILKernelInfo(F.equalize_image_pil, kernel_name="equalize_image_pil"),
+        pil_kernel_info=PILKernelInfo(F._equalize_image_pil, kernel_name="equalize_image_pil"),
     ),
     DispatcherInfo(
         F.invert,
         kernels={
-            datapoints.Image: F.invert_image_tensor,
+            datapoints.Image: F.invert_image,
             datapoints.Video: F.invert_video,
         },
-        pil_kernel_info=PILKernelInfo(F.invert_image_pil, kernel_name="invert_image_pil"),
+        pil_kernel_info=PILKernelInfo(F._invert_image_pil, kernel_name="invert_image_pil"),
     ),
     DispatcherInfo(
         F.posterize,
         kernels={
-            datapoints.Image: F.posterize_image_tensor,
+            datapoints.Image: F.posterize_image,
             datapoints.Video: F.posterize_video,
         },
-        pil_kernel_info=PILKernelInfo(F.posterize_image_pil, kernel_name="posterize_image_pil"),
+        pil_kernel_info=PILKernelInfo(F._posterize_image_pil, kernel_name="posterize_image_pil"),
     ),
     DispatcherInfo(
         F.solarize,
         kernels={
-            datapoints.Image: F.solarize_image_tensor,
+            datapoints.Image: F.solarize_image,
             datapoints.Video: F.solarize_video,
         },
-        pil_kernel_info=PILKernelInfo(F.solarize_image_pil, kernel_name="solarize_image_pil"),
+        pil_kernel_info=PILKernelInfo(F._solarize_image_pil, kernel_name="solarize_image_pil"),
     ),
     DispatcherInfo(
         F.autocontrast,
         kernels={
-            datapoints.Image: F.autocontrast_image_tensor,
+            datapoints.Image: F.autocontrast_image,
             datapoints.Video: F.autocontrast_video,
         },
-        pil_kernel_info=PILKernelInfo(F.autocontrast_image_pil, kernel_name="autocontrast_image_pil"),
+        pil_kernel_info=PILKernelInfo(F._autocontrast_image_pil, kernel_name="autocontrast_image_pil"),
     ),
     DispatcherInfo(
         F.adjust_sharpness,
         kernels={
-            datapoints.Image: F.adjust_sharpness_image_tensor,
+            datapoints.Image: F.adjust_sharpness_image,
             datapoints.Video: F.adjust_sharpness_video,
         },
-        pil_kernel_info=PILKernelInfo(F.adjust_sharpness_image_pil, kernel_name="adjust_sharpness_image_pil"),
+        pil_kernel_info=PILKernelInfo(F._adjust_sharpness_image_pil, kernel_name="adjust_sharpness_image_pil"),
     ),
     DispatcherInfo(
         F.erase,
         kernels={
-            datapoints.Image: F.erase_image_tensor,
+            datapoints.Image: F.erase_image,
             datapoints.Video: F.erase_video,
         },
-        pil_kernel_info=PILKernelInfo(F.erase_image_pil),
+        pil_kernel_info=PILKernelInfo(F._erase_image_pil),
         test_marks=[
             skip_dispatch_datapoint,
         ],
@@ -293,42 +293,42 @@ DISPATCHER_INFOS = [
     DispatcherInfo(
         F.adjust_contrast,
         kernels={
-            datapoints.Image: F.adjust_contrast_image_tensor,
+            datapoints.Image: F.adjust_contrast_image,
             datapoints.Video: F.adjust_contrast_video,
         },
-        pil_kernel_info=PILKernelInfo(F.adjust_contrast_image_pil, kernel_name="adjust_contrast_image_pil"),
+        pil_kernel_info=PILKernelInfo(F._adjust_contrast_image_pil, kernel_name="adjust_contrast_image_pil"),
     ),
     DispatcherInfo(
         F.adjust_gamma,
         kernels={
-            datapoints.Image: F.adjust_gamma_image_tensor,
+            datapoints.Image: F.adjust_gamma_image,
             datapoints.Video: F.adjust_gamma_video,
         },
-        pil_kernel_info=PILKernelInfo(F.adjust_gamma_image_pil, kernel_name="adjust_gamma_image_pil"),
+        pil_kernel_info=PILKernelInfo(F._adjust_gamma_image_pil, kernel_name="adjust_gamma_image_pil"),
     ),
     DispatcherInfo(
         F.adjust_hue,
         kernels={
-            datapoints.Image: F.adjust_hue_image_tensor,
+            datapoints.Image: F.adjust_hue_image,
             datapoints.Video: F.adjust_hue_video,
         },
-        pil_kernel_info=PILKernelInfo(F.adjust_hue_image_pil, kernel_name="adjust_hue_image_pil"),
+        pil_kernel_info=PILKernelInfo(F._adjust_hue_image_pil, kernel_name="adjust_hue_image_pil"),
     ),
     DispatcherInfo(
         F.adjust_saturation,
         kernels={
-            datapoints.Image: F.adjust_saturation_image_tensor,
+            datapoints.Image: F.adjust_saturation_image,
             datapoints.Video: F.adjust_saturation_video,
         },
-        pil_kernel_info=PILKernelInfo(F.adjust_saturation_image_pil, kernel_name="adjust_saturation_image_pil"),
+        pil_kernel_info=PILKernelInfo(F._adjust_saturation_image_pil, kernel_name="adjust_saturation_image_pil"),
    ),
     DispatcherInfo(
         F.five_crop,
         kernels={
-            datapoints.Image: F.five_crop_image_tensor,
+            datapoints.Image: F.five_crop_image,
             datapoints.Video: F.five_crop_video,
         },
-        pil_kernel_info=PILKernelInfo(F.five_crop_image_pil),
+        pil_kernel_info=PILKernelInfo(F._five_crop_image_pil),
         test_marks=[
             xfail_jit_python_scalar_arg("size"),
             *multi_crop_skips,
@@ -337,19 +337,19 @@ DISPATCHER_INFOS = [
     DispatcherInfo(
         F.ten_crop,
         kernels={
-            datapoints.Image: F.ten_crop_image_tensor,
+            datapoints.Image: F.ten_crop_image,
             datapoints.Video: F.ten_crop_video,
         },
         test_marks=[
             xfail_jit_python_scalar_arg("size"),
             *multi_crop_skips,
         ],
-        pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil),
+        pil_kernel_info=PILKernelInfo(F._ten_crop_image_pil),
     ),
     DispatcherInfo(
         F.normalize,
         kernels={
-            datapoints.Image: F.normalize_image_tensor,
+            datapoints.Image: F.normalize_image,
             datapoints.Video: F.normalize_video,
         },
         test_marks=[
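Every entry above has the same shape: the dispatcher maps each datapoint type to its (now suffix-free) tensor kernel, while the newly private PIL kernel is tracked separately. A hedged sketch of that type-based dispatch pattern; the names `KERNELS` and `dispatch` are illustrative and not taken from the test helpers:

import PIL.Image
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

# Illustrative dispatch table in the spirit of DISPATCHER_INFOS:
KERNELS = {
    datapoints.Image: F.crop_image,            # was F.crop_image_tensor
    datapoints.Video: F.crop_video,
    datapoints.BoundingBoxes: F.crop_bounding_boxes,
    datapoints.Mask: F.crop_mask,
    PIL.Image.Image: F._crop_image_pil,        # was the public F.crop_image_pil
}

def dispatch(inpt, *args, **kwargs):
    # Pick the kernel registered for the input's type.
    for typ, kernel in KERNELS.items():
        if isinstance(inpt, typ):
            return kernel(inpt, *args, **kwargs)
    raise TypeError(f"no kernel registered for {type(inpt)}")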
test/transforms_v2_kernel_infos.py
@@ -122,12 +122,12 @@ def pil_reference_wrapper(pil_kernel):
                 f"Can only test single tensor images against PIL, but input has shape {input_tensor.shape}"
             )

-        input_pil = F.to_image_pil(input_tensor)
+        input_pil = F.to_pil_image(input_tensor)
         output_pil = pil_kernel(input_pil, *other_args, **kwargs)

         if not isinstance(output_pil, PIL.Image.Image):
             return output_pil

-        output_tensor = F.to_image_tensor(output_pil)
+        output_tensor = F.to_image(output_pil)

         # 2D mask shenanigans
         if output_tensor.ndim == 2 and input_tensor.ndim == 3:
@@ -331,10 +331,10 @@ def reference_inputs_crop_bounding_boxes():
 KERNEL_INFOS.extend(
     [
         KernelInfo(
-            F.crop_image_tensor,
+            F.crop_image,
             kernel_name="crop_image_tensor",
             sample_inputs_fn=sample_inputs_crop_image_tensor,
-            reference_fn=pil_reference_wrapper(F.crop_image_pil),
+            reference_fn=pil_reference_wrapper(F._crop_image_pil),
             reference_inputs_fn=reference_inputs_crop_image_tensor,
             float32_vs_uint8=True,
         ),
@@ -347,7 +347,7 @@ KERNEL_INFOS.extend(
         KernelInfo(
             F.crop_mask,
             sample_inputs_fn=sample_inputs_crop_mask,
-            reference_fn=pil_reference_wrapper(F.crop_image_pil),
+            reference_fn=pil_reference_wrapper(F._crop_image_pil),
             reference_inputs_fn=reference_inputs_crop_mask,
             float32_vs_uint8=True,
         ),
@@ -373,7 +373,7 @@ def reference_resized_crop_image_tensor(*args, **kwargs):
         F.InterpolationMode.BICUBIC,
     }:
         raise pytest.UsageError("Anti-aliasing is always active in PIL")
-    return F.resized_crop_image_pil(*args, **kwargs)
+    return F._resized_crop_image_pil(*args, **kwargs)

 def reference_inputs_resized_crop_image_tensor():
@@ -417,7 +417,7 @@ def sample_inputs_resized_crop_video():
 KERNEL_INFOS.extend(
     [
         KernelInfo(
-            F.resized_crop_image_tensor,
+            F.resized_crop_image,
             sample_inputs_fn=sample_inputs_resized_crop_image_tensor,
             reference_fn=reference_resized_crop_image_tensor,
             reference_inputs_fn=reference_inputs_resized_crop_image_tensor,
@@ -570,9 +570,9 @@ def pad_xfail_jit_fill_condition(args_kwargs):
 KERNEL_INFOS.extend(
     [
         KernelInfo(
-            F.pad_image_tensor,
+            F.pad_image,
             sample_inputs_fn=sample_inputs_pad_image_tensor,
-            reference_fn=pil_reference_wrapper(F.pad_image_pil),
+            reference_fn=pil_reference_wrapper(F._pad_image_pil),
             reference_inputs_fn=reference_inputs_pad_image_tensor,
             float32_vs_uint8=float32_vs_uint8_fill_adapter,
             closeness_kwargs=float32_vs_uint8_pixel_difference(),
@@ -595,7 +595,7 @@ KERNEL_INFOS.extend(
         KernelInfo(
             F.pad_mask,
             sample_inputs_fn=sample_inputs_pad_mask,
-            reference_fn=pil_reference_wrapper(F.pad_image_pil),
+            reference_fn=pil_reference_wrapper(F._pad_image_pil),
             reference_inputs_fn=reference_inputs_pad_mask,
             float32_vs_uint8=float32_vs_uint8_fill_adapter,
         ),
@@ -690,9 +690,9 @@ def sample_inputs_perspective_video():
 KERNEL_INFOS.extend(
     [
         KernelInfo(
-            F.perspective_image_tensor,
+            F.perspective_image,
             sample_inputs_fn=sample_inputs_perspective_image_tensor,
-            reference_fn=pil_reference_wrapper(F.perspective_image_pil),
+            reference_fn=pil_reference_wrapper(F._perspective_image_pil),
             reference_inputs_fn=reference_inputs_perspective_image_tensor,
             float32_vs_uint8=float32_vs_uint8_fill_adapter,
             closeness_kwargs={
@@ -715,7 +715,7 @@ KERNEL_INFOS.extend(
         KernelInfo(
             F.perspective_mask,
             sample_inputs_fn=sample_inputs_perspective_mask,
-            reference_fn=pil_reference_wrapper(F.perspective_image_pil),
+            reference_fn=pil_reference_wrapper(F._perspective_image_pil),
             reference_inputs_fn=reference_inputs_perspective_mask,
             float32_vs_uint8=True,
             closeness_kwargs={
@@ -786,7 +786,7 @@ def sample_inputs_elastic_video():
 KERNEL_INFOS.extend(
     [
         KernelInfo(
-            F.elastic_image_tensor,
+            F.elastic_image,
             sample_inputs_fn=sample_inputs_elastic_image_tensor,
             reference_inputs_fn=reference_inputs_elastic_image_tensor,
             float32_vs_uint8=float32_vs_uint8_fill_adapter,
@@ -870,9 +870,9 @@ def sample_inputs_center_crop_video():
 KERNEL_INFOS.extend(
     [
         KernelInfo(
-            F.center_crop_image_tensor,
+            F.center_crop_image,
             sample_inputs_fn=sample_inputs_center_crop_image_tensor,
-            reference_fn=pil_reference_wrapper(F.center_crop_image_pil),
+            reference_fn=pil_reference_wrapper(F._center_crop_image_pil),
             reference_inputs_fn=reference_inputs_center_crop_image_tensor,
             float32_vs_uint8=True,
             test_marks=[
@@ -889,7 +889,7 @@ KERNEL_INFOS.extend(
         KernelInfo(
             F.center_crop_mask,
             sample_inputs_fn=sample_inputs_center_crop_mask,
-            reference_fn=pil_reference_wrapper(F.center_crop_image_pil),
+            reference_fn=pil_reference_wrapper(F._center_crop_image_pil),
             reference_inputs_fn=reference_inputs_center_crop_mask,
             float32_vs_uint8=True,
             test_marks=[
@@ -924,7 +924,7 @@ def sample_inputs_gaussian_blur_video():
 KERNEL_INFOS.extend(
     [
         KernelInfo(
-            F.gaussian_blur_image_tensor,
+            F.gaussian_blur_image,
             sample_inputs_fn=sample_inputs_gaussian_blur_image_tensor,
             closeness_kwargs=cuda_vs_cpu_pixel_difference(),
             test_marks=[
@@ -1010,10 +1010,10 @@ def sample_inputs_equalize_video():
 KERNEL_INFOS.extend(
     [
         KernelInfo(
-            F.equalize_image_tensor,
+            F.equalize_image,
             kernel_name="equalize_image_tensor",
             sample_inputs_fn=sample_inputs_equalize_image_tensor,
-            reference_fn=pil_reference_wrapper(F.equalize_image_pil),
+            reference_fn=pil_reference_wrapper(F._equalize_image_pil),
             float32_vs_uint8=True,
             reference_inputs_fn=reference_inputs_equalize_image_tensor,
         ),
@@ -1043,10 +1043,10 @@ def sample_inputs_invert_video():
 KERNEL_INFOS.extend(
     [
         KernelInfo(
-            F.invert_image_tensor,
+            F.invert_image,
             kernel_name="invert_image_tensor",
             sample_inputs_fn=sample_inputs_invert_image_tensor,
-            reference_fn=pil_reference_wrapper(F.invert_image_pil),
+            reference_fn=pil_reference_wrapper(F._invert_image_pil),
             reference_inputs_fn=reference_inputs_invert_image_tensor,
             float32_vs_uint8=True,
         ),
@@ -1082,10 +1082,10 @@ def sample_inputs_posterize_video():
 KERNEL_INFOS.extend(
     [
         KernelInfo(
-            F.posterize_image_tensor,
+            F.posterize_image,
             kernel_name="posterize_image_tensor",
             sample_inputs_fn=sample_inputs_posterize_image_tensor,
-            reference_fn=pil_reference_wrapper(F.posterize_image_pil),
+            reference_fn=pil_reference_wrapper(F._posterize_image_pil),
             reference_inputs_fn=reference_inputs_posterize_image_tensor,
             float32_vs_uint8=True,
             closeness_kwargs=float32_vs_uint8_pixel_difference(),
@@ -1127,10 +1127,10 @@ def sample_inputs_solarize_video():
 KERNEL_INFOS.extend(
     [
         KernelInfo(
-            F.solarize_image_tensor,
+            F.solarize_image,
             kernel_name="solarize_image_tensor",
             sample_inputs_fn=sample_inputs_solarize_image_tensor,
-            reference_fn=pil_reference_wrapper(F.solarize_image_pil),
+            reference_fn=pil_reference_wrapper(F._solarize_image_pil),
             reference_inputs_fn=reference_inputs_solarize_image_tensor,
             float32_vs_uint8=uint8_to_float32_threshold_adapter,
             closeness_kwargs=float32_vs_uint8_pixel_difference(),
@@ -1161,10 +1161,10 @@ def sample_inputs_autocontrast_video():
 KERNEL_INFOS.extend(
     [
         KernelInfo(
-            F.autocontrast_image_tensor,
+            F.autocontrast_image,
             kernel_name="autocontrast_image_tensor",
             sample_inputs_fn=sample_inputs_autocontrast_image_tensor,
-            reference_fn=pil_reference_wrapper(F.autocontrast_image_pil),
+            reference_fn=pil_reference_wrapper(F._autocontrast_image_pil),
             reference_inputs_fn=reference_inputs_autocontrast_image_tensor,
             float32_vs_uint8=True,
             closeness_kwargs={
@@ -1206,10 +1206,10 @@ def sample_inputs_adjust_sharpness_video():
 KERNEL_INFOS.extend(
     [
         KernelInfo(
-            F.adjust_sharpness_image_tensor,
+            F.adjust_sharpness_image,
             kernel_name="adjust_sharpness_image_tensor",
             sample_inputs_fn=sample_inputs_adjust_sharpness_image_tensor,
-            reference_fn=pil_reference_wrapper(F.adjust_sharpness_image_pil),
+            reference_fn=pil_reference_wrapper(F._adjust_sharpness_image_pil),
             reference_inputs_fn=reference_inputs_adjust_sharpness_image_tensor,
             float32_vs_uint8=True,
             closeness_kwargs=float32_vs_uint8_pixel_difference(2),
@@ -1241,7 +1241,7 @@ def sample_inputs_erase_video():
 KERNEL_INFOS.extend(
     [
         KernelInfo(
-            F.erase_image_tensor,
+            F.erase_image,
             kernel_name="erase_image_tensor",
             sample_inputs_fn=sample_inputs_erase_image_tensor,
         ),
@@ -1276,10 +1276,10 @@ def sample_inputs_adjust_contrast_video():
 KERNEL_INFOS.extend(
     [
         KernelInfo(
-            F.adjust_contrast_image_tensor,
+            F.adjust_contrast_image,
             kernel_name="adjust_contrast_image_tensor",
             sample_inputs_fn=sample_inputs_adjust_contrast_image_tensor,
-            reference_fn=pil_reference_wrapper(F.adjust_contrast_image_pil),
+            reference_fn=pil_reference_wrapper(F._adjust_contrast_image_pil),
             reference_inputs_fn=reference_inputs_adjust_contrast_image_tensor,
             float32_vs_uint8=True,
             closeness_kwargs={
@@ -1329,10 +1329,10 @@ def sample_inputs_adjust_gamma_video():
 KERNEL_INFOS.extend(
     [
         KernelInfo(
-            F.adjust_gamma_image_tensor,
+            F.adjust_gamma_image,
             kernel_name="adjust_gamma_image_tensor",
             sample_inputs_fn=sample_inputs_adjust_gamma_image_tensor,
-            reference_fn=pil_reference_wrapper(F.adjust_gamma_image_pil),
+            reference_fn=pil_reference_wrapper(F._adjust_gamma_image_pil),
             reference_inputs_fn=reference_inputs_adjust_gamma_image_tensor,
             float32_vs_uint8=True,
             closeness_kwargs={
@@ -1372,10 +1372,10 @@ def sample_inputs_adjust_hue_video():
 KERNEL_INFOS.extend(
     [
         KernelInfo(
-            F.adjust_hue_image_tensor,
+            F.adjust_hue_image,
             kernel_name="adjust_hue_image_tensor",
             sample_inputs_fn=sample_inputs_adjust_hue_image_tensor,
-            reference_fn=pil_reference_wrapper(F.adjust_hue_image_pil),
+            reference_fn=pil_reference_wrapper(F._adjust_hue_image_pil),
             reference_inputs_fn=reference_inputs_adjust_hue_image_tensor,
             float32_vs_uint8=True,
             closeness_kwargs={
@@ -1414,10 +1414,10 @@ def sample_inputs_adjust_saturation_video():
 KERNEL_INFOS.extend(
     [
         KernelInfo(
-            F.adjust_saturation_image_tensor,
+            F.adjust_saturation_image,
             kernel_name="adjust_saturation_image_tensor",
             sample_inputs_fn=sample_inputs_adjust_saturation_image_tensor,
-            reference_fn=pil_reference_wrapper(F.adjust_saturation_image_pil),
+            reference_fn=pil_reference_wrapper(F._adjust_saturation_image_pil),
             reference_inputs_fn=reference_inputs_adjust_saturation_image_tensor,
             float32_vs_uint8=True,
             closeness_kwargs={
@@ -1517,8 +1517,7 @@ def multi_crop_pil_reference_wrapper(pil_kernel):
     def wrapper(input_tensor, *other_args, **kwargs):
         output = pil_reference_wrapper(pil_kernel)(input_tensor, *other_args, **kwargs)
         return type(output)(
-            F.to_dtype_image_tensor(F.to_image_tensor(output_pil), dtype=input_tensor.dtype, scale=True)
-            for output_pil in output
+            F.to_dtype_image(F.to_image(output_pil), dtype=input_tensor.dtype, scale=True) for output_pil in output
         )

     return wrapper
@@ -1532,9 +1531,9 @@ _common_five_ten_crop_marks = [
 KERNEL_INFOS.extend(
     [
         KernelInfo(
-            F.five_crop_image_tensor,
+            F.five_crop_image,
             sample_inputs_fn=sample_inputs_five_crop_image_tensor,
-            reference_fn=multi_crop_pil_reference_wrapper(F.five_crop_image_pil),
+            reference_fn=multi_crop_pil_reference_wrapper(F._five_crop_image_pil),
             reference_inputs_fn=reference_inputs_five_crop_image_tensor,
             test_marks=_common_five_ten_crop_marks,
         ),
@@ -1544,9 +1543,9 @@ KERNEL_INFOS.extend(
             test_marks=_common_five_ten_crop_marks,
         ),
         KernelInfo(
-            F.ten_crop_image_tensor,
+            F.ten_crop_image,
             sample_inputs_fn=sample_inputs_ten_crop_image_tensor,
-            reference_fn=multi_crop_pil_reference_wrapper(F.ten_crop_image_pil),
+            reference_fn=multi_crop_pil_reference_wrapper(F._ten_crop_image_pil),
             reference_inputs_fn=reference_inputs_ten_crop_image_tensor,
             test_marks=_common_five_ten_crop_marks,
         ),
@@ -1600,7 +1599,7 @@ def sample_inputs_normalize_video():
 KERNEL_INFOS.extend(
     [
         KernelInfo(
-            F.normalize_image_tensor,
+            F.normalize_image,
             kernel_name="normalize_image_tensor",
             sample_inputs_fn=sample_inputs_normalize_image_tensor,
             reference_fn=reference_normalize_image_tensor,
torchvision/prototype/transforms/_augment.py
@@ -112,7 +112,7 @@ class SimpleCopyPaste(Transform):
             if isinstance(obj, datapoints.Image) or is_simple_tensor(obj):
                 images.append(obj)
             elif isinstance(obj, PIL.Image.Image):
-                images.append(F.to_image_tensor(obj))
+                images.append(F.to_image(obj))
             elif isinstance(obj, datapoints.BoundingBoxes):
                 bboxes.append(obj)
             elif isinstance(obj, datapoints.Mask):
@@ -144,7 +144,7 @@ class SimpleCopyPaste(Transform):
                 flat_sample[i] = datapoints.wrap(output_images[c0], like=obj)
                 c0 += 1
             elif isinstance(obj, PIL.Image.Image):
-                flat_sample[i] = F.to_image_pil(output_images[c0])
+                flat_sample[i] = F.to_pil_image(output_images[c0])
                 c0 += 1
             elif is_simple_tensor(obj):
                 flat_sample[i] = output_images[c0]
torchvision/transforms/v2/__init__.py
@@ -52,7 +52,7 @@ from ._misc import (
     ToDtype,
 )
 from ._temporal import UniformTemporalSubsample
-from ._type_conversion import PILToTensor, ToImagePIL, ToImageTensor, ToPILImage
+from ._type_conversion import PILToTensor, ToImage, ToPILImage

 from ._deprecated import ToTensor  # usort: skip
torchvision/transforms/v2/_auto_augment.py
@@ -622,6 +622,6 @@ class AugMix(_AutoAugmentBase):
         if isinstance(orig_image_or_video, (datapoints.Image, datapoints.Video)):
             mix = datapoints.wrap(mix, like=orig_image_or_video)
         elif isinstance(orig_image_or_video, PIL.Image.Image):
-            mix = F.to_image_pil(mix)
+            mix = F.to_pil_image(mix)

         return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, mix)
torchvision/transforms/v2/_type_conversion.py
View file @
ca012d39
...
...
@@ -26,7 +26,7 @@ class PILToTensor(Transform):
         return F.pil_to_tensor(inpt)


-class ToImageTensor(Transform):
+class ToImage(Transform):
     """[BETA] Convert a tensor, ndarray, or PIL Image to :class:`~torchvision.datapoints.Image`;
     this does not scale values.
...
...
@@ -40,10 +40,10 @@ class ToImageTensor(Transform):
     def _transform(
         self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
     ) -> datapoints.Image:
-        return F.to_image_tensor(inpt)
+        return F.to_image(inpt)


-class ToImagePIL(Transform):
+class ToPILImage(Transform):
     """[BETA] Convert a tensor or an ndarray to PIL Image - this does not scale values.

     .. v2betastatus:: ToImagePIL transform
...
...
@@ -74,9 +74,4 @@ class ToImagePIL(Transform):
     def _transform(
         self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
     ) -> PIL.Image.Image:
-        return F.to_image_pil(inpt, mode=self.mode)
-
-
-# We changed the name to align them with the new naming scheme. Still, `ToPILImage` is
-# prevalent and well understood. Thus, we just alias it without deprecating the old name.
-ToPILImage = ToImagePIL
+        return F.to_pil_image(inpt, mode=self.mode)
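`ToPILImage` is now the class itself rather than an alias of `ToImagePIL`, and it delegates to the renamed `F.to_pil_image`. A minimal usage sketch (assuming torchvision at this commit):

```python
import torch
import torchvision.transforms.v2 as transforms

to_pil = transforms.ToPILImage()
img = to_pil(torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8))
print(type(img))  # <class 'PIL.Image.Image'>
```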
torchvision/transforms/v2/functional/__init__.py
...
...
@@ -5,173 +5,173 @@ from ._utils import is_simple_tensor, register_kernel  # usort: skip
 from ._meta import (
     clamp_bounding_boxes,
     convert_format_bounding_boxes,
-    get_dimensions_image_tensor,
-    get_dimensions_image_pil,
+    get_dimensions_image,
+    _get_dimensions_image_pil,
     get_dimensions_video,
     get_dimensions,
     get_num_frames_video,
     get_num_frames,
     get_image_num_channels,
-    get_num_channels_image_tensor,
-    get_num_channels_image_pil,
+    get_num_channels_image,
+    _get_num_channels_image_pil,
     get_num_channels_video,
     get_num_channels,
     get_size_bounding_boxes,
-    get_size_image_tensor,
-    get_size_image_pil,
+    get_size_image,
+    _get_size_image_pil,
     get_size_mask,
     get_size_video,
     get_size,
 )  # usort: skip

-from ._augment import erase, erase_image_pil, erase_image_tensor, erase_video
+from ._augment import _erase_image_pil, erase, erase_image, erase_video
 from ._color import (
+    _adjust_brightness_image_pil,
+    _adjust_contrast_image_pil,
+    _adjust_gamma_image_pil,
+    _adjust_hue_image_pil,
+    _adjust_saturation_image_pil,
+    _adjust_sharpness_image_pil,
+    _autocontrast_image_pil,
+    _equalize_image_pil,
+    _invert_image_pil,
+    _permute_channels_image_pil,
+    _posterize_image_pil,
+    _rgb_to_grayscale_image_pil,
+    _solarize_image_pil,
     adjust_brightness,
-    adjust_brightness_image_pil,
-    adjust_brightness_image_tensor,
+    adjust_brightness_image,
     adjust_brightness_video,
     adjust_contrast,
-    adjust_contrast_image_pil,
-    adjust_contrast_image_tensor,
+    adjust_contrast_image,
     adjust_contrast_video,
     adjust_gamma,
-    adjust_gamma_image_pil,
-    adjust_gamma_image_tensor,
+    adjust_gamma_image,
     adjust_gamma_video,
     adjust_hue,
-    adjust_hue_image_pil,
-    adjust_hue_image_tensor,
+    adjust_hue_image,
     adjust_hue_video,
     adjust_saturation,
-    adjust_saturation_image_pil,
-    adjust_saturation_image_tensor,
+    adjust_saturation_image,
     adjust_saturation_video,
     adjust_sharpness,
-    adjust_sharpness_image_pil,
-    adjust_sharpness_image_tensor,
+    adjust_sharpness_image,
     adjust_sharpness_video,
     autocontrast,
-    autocontrast_image_pil,
-    autocontrast_image_tensor,
+    autocontrast_image,
     autocontrast_video,
     equalize,
-    equalize_image_pil,
-    equalize_image_tensor,
+    equalize_image,
     equalize_video,
     invert,
-    invert_image_pil,
-    invert_image_tensor,
+    invert_image,
     invert_video,
     permute_channels,
-    permute_channels_image_pil,
-    permute_channels_image_tensor,
+    permute_channels_image,
     permute_channels_video,
     posterize,
-    posterize_image_pil,
-    posterize_image_tensor,
+    posterize_image,
     posterize_video,
     rgb_to_grayscale,
-    rgb_to_grayscale_image_pil,
-    rgb_to_grayscale_image_tensor,
+    rgb_to_grayscale_image,
     solarize,
-    solarize_image_pil,
-    solarize_image_tensor,
+    solarize_image,
     solarize_video,
     to_grayscale,
 )
 from ._geometry import (
+    _affine_image_pil,
+    _center_crop_image_pil,
+    _crop_image_pil,
+    _elastic_image_pil,
+    _five_crop_image_pil,
+    _horizontal_flip_image_pil,
+    _pad_image_pil,
+    _perspective_image_pil,
+    _resize_image_pil,
+    _resized_crop_image_pil,
+    _rotate_image_pil,
+    _ten_crop_image_pil,
+    _vertical_flip_image_pil,
     affine,
     affine_bounding_boxes,
-    affine_image_pil,
-    affine_image_tensor,
+    affine_image,
     affine_mask,
     affine_video,
     center_crop,
     center_crop_bounding_boxes,
-    center_crop_image_pil,
-    center_crop_image_tensor,
+    center_crop_image,
     center_crop_mask,
     center_crop_video,
     crop,
     crop_bounding_boxes,
-    crop_image_pil,
-    crop_image_tensor,
+    crop_image,
     crop_mask,
     crop_video,
     elastic,
     elastic_bounding_boxes,
-    elastic_image_pil,
-    elastic_image_tensor,
+    elastic_image,
     elastic_mask,
     elastic_transform,
     elastic_video,
     five_crop,
-    five_crop_image_pil,
-    five_crop_image_tensor,
+    five_crop_image,
     five_crop_video,
     hflip,
     # TODO: Consider moving all pure alias definitions at the bottom of the file
     horizontal_flip,
     horizontal_flip_bounding_boxes,
-    horizontal_flip_image_pil,
-    horizontal_flip_image_tensor,
+    horizontal_flip_image,
     horizontal_flip_mask,
     horizontal_flip_video,
     pad,
     pad_bounding_boxes,
-    pad_image_pil,
-    pad_image_tensor,
+    pad_image,
     pad_mask,
     pad_video,
     perspective,
     perspective_bounding_boxes,
-    perspective_image_pil,
-    perspective_image_tensor,
+    perspective_image,
     perspective_mask,
     perspective_video,
     resize,
     resize_bounding_boxes,
-    resize_image_pil,
-    resize_image_tensor,
+    resize_image,
     resize_mask,
     resize_video,
     resized_crop,
     resized_crop_bounding_boxes,
-    resized_crop_image_pil,
-    resized_crop_image_tensor,
+    resized_crop_image,
     resized_crop_mask,
     resized_crop_video,
     rotate,
     rotate_bounding_boxes,
-    rotate_image_pil,
-    rotate_image_tensor,
+    rotate_image,
     rotate_mask,
     rotate_video,
     ten_crop,
-    ten_crop_image_pil,
-    ten_crop_image_tensor,
+    ten_crop_image,
     ten_crop_video,
     vertical_flip,
     vertical_flip_bounding_boxes,
-    vertical_flip_image_pil,
-    vertical_flip_image_tensor,
+    vertical_flip_image,
     vertical_flip_mask,
     vertical_flip_video,
     vflip,
 )
 from ._misc import (
+    _gaussian_blur_image_pil,
     convert_image_dtype,
     gaussian_blur,
-    gaussian_blur_image_pil,
-    gaussian_blur_image_tensor,
+    gaussian_blur_image,
     gaussian_blur_video,
     normalize,
-    normalize_image_tensor,
+    normalize_image,
     normalize_video,
     to_dtype,
-    to_dtype_image_tensor,
+    to_dtype_image,
     to_dtype_video,
 )
 from ._temporal import uniform_temporal_subsample, uniform_temporal_subsample_video
-from ._type_conversion import pil_to_tensor, to_image_pil, to_image_tensor, to_pil_image
+from ._type_conversion import pil_to_tensor, to_image, to_pil_image

 from ._deprecated import get_image_size, to_tensor  # usort: skip
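The re-export list encodes the new naming scheme: tensor kernels drop the `_tensor` suffix and stay public, while every PIL kernel gains a leading underscore and becomes private. A sketch of what that means at the call site (assuming torchvision at this commit):

```python
import torchvision.transforms.v2.functional as F

# Public per-type kernels, formerly suffixed with `_tensor`:
F.adjust_brightness_image   # was F.adjust_brightness_image_tensor
F.resize_image              # was F.resize_image_tensor

# PIL kernels still exist but are private by convention now:
F._adjust_brightness_image_pil  # was F.adjust_brightness_image_pil
```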
torchvision/transforms/v2/functional/_augment.py
...
...
@@ -18,7 +18,7 @@ def erase(
     inplace: bool = False,
 ) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
+        return erase_image(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)

     _log_api_usage_once(erase)
...
...
@@ -28,7 +28,7 @@ def erase(

 @_register_kernel_internal(erase, torch.Tensor)
 @_register_kernel_internal(erase, datapoints.Image)
-def erase_image_tensor(
+def erase_image(
     image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
 ) -> torch.Tensor:
     if not inplace:
...
...
@@ -39,11 +39,11 @@ def erase_image_tensor(

 @_register_kernel_internal(erase, PIL.Image.Image)
-def erase_image_pil(
+def _erase_image_pil(
     image: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
 ) -> PIL.Image.Image:
     t_img = pil_to_tensor(image)
-    output = erase_image_tensor(t_img, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
+    output = erase_image(t_img, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
     return to_pil_image(output, mode=image.mode)
...
...
@@ -51,4 +51,4 @@ def erase_image_pil(
 def erase_video(
     video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
 ) -> torch.Tensor:
-    return erase_image_tensor(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
+    return erase_image(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
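`erase_video` keeps delegating to the image kernel, since a video tensor is just a batch of images along its leading dimensions. A sketch of the public entry point, which routes a plain tensor to `erase_image` (assuming torchvision at this commit):

```python
import torch
import torchvision.transforms.v2.functional as F

img = torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8)
patch = torch.zeros(3, 4, 4, dtype=torch.uint8)

# Overwrite the 4x4 region starting at row 2, column 2 with the patch.
out = F.erase(img, i=2, j=2, h=4, w=4, v=patch)
print(out[0, 2:6, 2:6].unique())  # tensor([0], dtype=torch.uint8)
```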
torchvision/transforms/v2/functional/_color.py
...
...
@@ -9,14 +9,14 @@ from torchvision.transforms._functional_tensor import _max_value
 from torchvision.utils import _log_api_usage_once

-from ._misc import _num_value_bits, to_dtype_image_tensor
-from ._type_conversion import pil_to_tensor, to_image_pil
+from ._misc import _num_value_bits, to_dtype_image
+from ._type_conversion import pil_to_tensor, to_pil_image
 from ._utils import _get_kernel, _register_kernel_internal


 def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return rgb_to_grayscale_image_tensor(inpt, num_output_channels=num_output_channels)
+        return rgb_to_grayscale_image(inpt, num_output_channels=num_output_channels)

     _log_api_usage_once(rgb_to_grayscale)
...
...
@@ -29,7 +29,7 @@ def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
 to_grayscale = rgb_to_grayscale


-def _rgb_to_grayscale_image_tensor(
+def _rgb_to_grayscale_image(
     image: torch.Tensor, num_output_channels: int = 1, preserve_dtype: bool = True
 ) -> torch.Tensor:
     if image.shape[-3] == 1:
...
...
@@ -47,14 +47,14 @@ def _rgb_to_grayscale_image_tensor(

 @_register_kernel_internal(rgb_to_grayscale, torch.Tensor)
 @_register_kernel_internal(rgb_to_grayscale, datapoints.Image)
-def rgb_to_grayscale_image_tensor(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
+def rgb_to_grayscale_image(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
     if num_output_channels not in (1, 3):
         raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
-    return _rgb_to_grayscale_image_tensor(image, num_output_channels=num_output_channels, preserve_dtype=True)
+    return _rgb_to_grayscale_image(image, num_output_channels=num_output_channels, preserve_dtype=True)


 @_register_kernel_internal(rgb_to_grayscale, PIL.Image.Image)
-def rgb_to_grayscale_image_pil(image: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:
+def _rgb_to_grayscale_image_pil(image: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:
     if num_output_channels not in (1, 3):
         raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
     return _FP.to_grayscale(image, num_output_channels=num_output_channels)
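The public tensor kernel now validates `num_output_channels` and defers to the private helper, while the PIL kernel becomes `_rgb_to_grayscale_image_pil`. A usage sketch of the renamed public kernel:

```python
import torch
import torchvision.transforms.v2.functional as F

rgb = torch.rand(3, 8, 8)
gray = F.rgb_to_grayscale_image(rgb, num_output_channels=1)
print(gray.shape)  # torch.Size([1, 8, 8])
```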
...
...
@@ -71,7 +71,7 @@ def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:

 def adjust_brightness(inpt: torch.Tensor, brightness_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
+        return adjust_brightness_image(inpt, brightness_factor=brightness_factor)

     _log_api_usage_once(adjust_brightness)
...
...
@@ -81,7 +81,7 @@ def adjust_brightness(inpt: torch.Tensor, brightness_factor: float) -> torch.Tensor:

 @_register_kernel_internal(adjust_brightness, torch.Tensor)
 @_register_kernel_internal(adjust_brightness, datapoints.Image)
-def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float) -> torch.Tensor:
+def adjust_brightness_image(image: torch.Tensor, brightness_factor: float) -> torch.Tensor:
     if brightness_factor < 0:
         raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")
...
...
@@ -96,18 +96,18 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float) -> torch.Tensor:

 @_register_kernel_internal(adjust_brightness, PIL.Image.Image)
-def adjust_brightness_image_pil(image: PIL.Image.Image, brightness_factor: float) -> PIL.Image.Image:
+def _adjust_brightness_image_pil(image: PIL.Image.Image, brightness_factor: float) -> PIL.Image.Image:
     return _FP.adjust_brightness(image, brightness_factor=brightness_factor)


 @_register_kernel_internal(adjust_brightness, datapoints.Video)
 def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor:
-    return adjust_brightness_image_tensor(video, brightness_factor=brightness_factor)
+    return adjust_brightness_image(video, brightness_factor=brightness_factor)


 def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
+        return adjust_saturation_image(inpt, saturation_factor=saturation_factor)

     _log_api_usage_once(adjust_saturation)
...
...
@@ -117,7 +117,7 @@ def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Tensor:

 @_register_kernel_internal(adjust_saturation, torch.Tensor)
 @_register_kernel_internal(adjust_saturation, datapoints.Image)
-def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float) -> torch.Tensor:
+def adjust_saturation_image(image: torch.Tensor, saturation_factor: float) -> torch.Tensor:
     if saturation_factor < 0:
         raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")
...
...
@@ -128,24 +128,24 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float) -> torch.Tensor:
     if c == 1:  # Match PIL behaviour
         return image

-    grayscale_image = _rgb_to_grayscale_image_tensor(image, num_output_channels=1, preserve_dtype=False)
+    grayscale_image = _rgb_to_grayscale_image(image, num_output_channels=1, preserve_dtype=False)
     if not image.is_floating_point():
         grayscale_image = grayscale_image.floor_()

     return _blend(image, grayscale_image, saturation_factor)


-adjust_saturation_image_pil = _register_kernel_internal(adjust_saturation, PIL.Image.Image)(_FP.adjust_saturation)
+_adjust_saturation_image_pil = _register_kernel_internal(adjust_saturation, PIL.Image.Image)(_FP.adjust_saturation)


 @_register_kernel_internal(adjust_saturation, datapoints.Video)
 def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> torch.Tensor:
-    return adjust_saturation_image_tensor(video, saturation_factor=saturation_factor)
+    return adjust_saturation_image(video, saturation_factor=saturation_factor)


 def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
+        return adjust_contrast_image(inpt, contrast_factor=contrast_factor)

     _log_api_usage_once(adjust_contrast)
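Saturation adjustment blends the image with its grayscale version: a factor of 0 yields pure grayscale, 1 returns the input, and values above 1 over-saturate. A small sanity-check sketch (assuming torchvision at this commit):

```python
import torch
import torchvision.transforms.v2.functional as F

img = torch.rand(3, 8, 8)
# factor 1.0 puts all blend weight on the original image, so it is a no-op.
assert torch.allclose(F.adjust_saturation_image(img, saturation_factor=1.0), img)
```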
...
...
@@ -155,7 +155,7 @@ def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor:

 @_register_kernel_internal(adjust_contrast, torch.Tensor)
 @_register_kernel_internal(adjust_contrast, datapoints.Image)
-def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> torch.Tensor:
+def adjust_contrast_image(image: torch.Tensor, contrast_factor: float) -> torch.Tensor:
     if contrast_factor < 0:
         raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")
...
...
@@ -164,7 +164,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> torch.Tensor:
         raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
     fp = image.is_floating_point()
     if c == 3:
-        grayscale_image = _rgb_to_grayscale_image_tensor(image, num_output_channels=1, preserve_dtype=False)
+        grayscale_image = _rgb_to_grayscale_image(image, num_output_channels=1, preserve_dtype=False)
         if not fp:
             grayscale_image = grayscale_image.floor_()
     else:
...
...
@@ -173,17 +173,17 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> torch.Tensor:
     return _blend(image, mean, contrast_factor)


-adjust_contrast_image_pil = _register_kernel_internal(adjust_contrast, PIL.Image.Image)(_FP.adjust_contrast)
+_adjust_contrast_image_pil = _register_kernel_internal(adjust_contrast, PIL.Image.Image)(_FP.adjust_contrast)


 @_register_kernel_internal(adjust_contrast, datapoints.Video)
 def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.Tensor:
-    return adjust_contrast_image_tensor(video, contrast_factor=contrast_factor)
+    return adjust_contrast_image(video, contrast_factor=contrast_factor)


 def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
+        return adjust_sharpness_image(inpt, sharpness_factor=sharpness_factor)

     _log_api_usage_once(adjust_sharpness)
...
...
@@ -193,7 +193,7 @@ def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tensor:

 @_register_kernel_internal(adjust_sharpness, torch.Tensor)
 @_register_kernel_internal(adjust_sharpness, datapoints.Image)
-def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
+def adjust_sharpness_image(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
     num_channels, height, width = image.shape[-3:]
     if num_channels not in (1, 3):
         raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}")
...
...
@@ -245,17 +245,17 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
     return output


-adjust_sharpness_image_pil = _register_kernel_internal(adjust_sharpness, PIL.Image.Image)(_FP.adjust_sharpness)
+_adjust_sharpness_image_pil = _register_kernel_internal(adjust_sharpness, PIL.Image.Image)(_FP.adjust_sharpness)


 @_register_kernel_internal(adjust_sharpness, datapoints.Video)
 def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
-    return adjust_sharpness_image_tensor(video, sharpness_factor=sharpness_factor)
+    return adjust_sharpness_image(video, sharpness_factor=sharpness_factor)


 def adjust_hue(inpt: torch.Tensor, hue_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
+        return adjust_hue_image(inpt, hue_factor=hue_factor)

     _log_api_usage_once(adjust_hue)
...
...
@@ -335,7 +335,7 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:

 @_register_kernel_internal(adjust_hue, torch.Tensor)
 @_register_kernel_internal(adjust_hue, datapoints.Image)
-def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
+def adjust_hue_image(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
     if not (-0.5 <= hue_factor <= 0.5):
         raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
...
...
@@ -351,7 +351,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
         return image

     orig_dtype = image.dtype
-    image = to_dtype_image_tensor(image, torch.float32, scale=True)
+    image = to_dtype_image(image, torch.float32, scale=True)

     image = _rgb_to_hsv(image)
     h, s, v = image.unbind(dim=-3)
...
...
@@ -359,20 +359,20 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
     image = torch.stack((h, s, v), dim=-3)

     image_hue_adj = _hsv_to_rgb(image)

-    return to_dtype_image_tensor(image_hue_adj, orig_dtype, scale=True)
+    return to_dtype_image(image_hue_adj, orig_dtype, scale=True)


-adjust_hue_image_pil = _register_kernel_internal(adjust_hue, PIL.Image.Image)(_FP.adjust_hue)
+_adjust_hue_image_pil = _register_kernel_internal(adjust_hue, PIL.Image.Image)(_FP.adjust_hue)


 @_register_kernel_internal(adjust_hue, datapoints.Video)
 def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
-    return adjust_hue_image_tensor(video, hue_factor=hue_factor)
+    return adjust_hue_image(video, hue_factor=hue_factor)


 def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
+        return adjust_gamma_image(inpt, gamma=gamma, gain=gain)

     _log_api_usage_once(adjust_gamma)
...
...
@@ -382,14 +382,14 @@ def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:

 @_register_kernel_internal(adjust_gamma, torch.Tensor)
 @_register_kernel_internal(adjust_gamma, datapoints.Image)
-def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor:
+def adjust_gamma_image(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor:
     if gamma < 0:
         raise ValueError("Gamma should be a non-negative real number")

     # The input image is either assumed to be at [0, 1] scale (if float) or is converted to that scale (if integer).
     # Since the gamma is non-negative, the output remains at [0, 1] scale.
     if not torch.is_floating_point(image):
-        output = to_dtype_image_tensor(image, torch.float32, scale=True).pow_(gamma)
+        output = to_dtype_image(image, torch.float32, scale=True).pow_(gamma)
     else:
         output = image.pow(gamma)
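For integer inputs the kernel first rescales to float in [0, 1], applies `output = clamp(gain * input ** gamma, 0, 1)`, and converts back. A worked sketch on a float image already in [0, 1]:

```python
import torch
import torchvision.transforms.v2.functional as F

img = torch.tensor([[[0.25, 1.0]]])  # shape (1, 1, 2), values in [0, 1]
# 0.25 ** 2 = 0.0625; 1.0 stays at 1.0.
print(F.adjust_gamma_image(img, gamma=2.0))  # tensor([[[0.0625, 1.0000]]])
```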
...
...
@@ -398,20 +398,20 @@ def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor:
     # of the output can go beyond [0, 1].
     output = output.mul_(gain).clamp_(0.0, 1.0)

-    return to_dtype_image_tensor(output, image.dtype, scale=True)
+    return to_dtype_image(output, image.dtype, scale=True)


-adjust_gamma_image_pil = _register_kernel_internal(adjust_gamma, PIL.Image.Image)(_FP.adjust_gamma)
+_adjust_gamma_image_pil = _register_kernel_internal(adjust_gamma, PIL.Image.Image)(_FP.adjust_gamma)


 @_register_kernel_internal(adjust_gamma, datapoints.Video)
 def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
-    return adjust_gamma_image_tensor(video, gamma=gamma, gain=gain)
+    return adjust_gamma_image(video, gamma=gamma, gain=gain)


 def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return posterize_image_tensor(inpt, bits=bits)
+        return posterize_image(inpt, bits=bits)

     _log_api_usage_once(posterize)
...
...
@@ -421,7 +421,7 @@ def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor:

 @_register_kernel_internal(posterize, torch.Tensor)
 @_register_kernel_internal(posterize, datapoints.Image)
-def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
+def posterize_image(image: torch.Tensor, bits: int) -> torch.Tensor:
     if image.is_floating_point():
         levels = 1 << bits
         return image.mul(levels).floor_().clamp_(0, levels - 1).mul_(1.0 / levels)
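For integer images, the rest of this function (shown in the next hunk) masks away the low bits, so `bits=3` on uint8 keeps only the top three bits, i.e. a mask of `0b11100000`. A numeric sketch:

```python
import torch
import torchvision.transforms.v2.functional as F

img = torch.tensor([[[255, 200, 37]]], dtype=torch.uint8)
# 255 -> 224, 200 -> 192, 37 -> 32 under the 0b11100000 mask.
print(F.posterize_image(img, bits=3))
```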
...
...
@@ -434,17 +434,17 @@ def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
     return image & mask


-posterize_image_pil = _register_kernel_internal(posterize, PIL.Image.Image)(_FP.posterize)
+_posterize_image_pil = _register_kernel_internal(posterize, PIL.Image.Image)(_FP.posterize)


 @_register_kernel_internal(posterize, datapoints.Video)
 def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
-    return posterize_image_tensor(video, bits=bits)
+    return posterize_image(video, bits=bits)


 def solarize(inpt: torch.Tensor, threshold: float) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return solarize_image_tensor(inpt, threshold=threshold)
+        return solarize_image(inpt, threshold=threshold)

     _log_api_usage_once(solarize)
...
...
@@ -454,24 +454,24 @@ def solarize(inpt: torch.Tensor, threshold: float) -> torch.Tensor:

 @_register_kernel_internal(solarize, torch.Tensor)
 @_register_kernel_internal(solarize, datapoints.Image)
-def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor:
+def solarize_image(image: torch.Tensor, threshold: float) -> torch.Tensor:
     if threshold > _max_value(image.dtype):
         raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")

-    return torch.where(image >= threshold, invert_image_tensor(image), image)
+    return torch.where(image >= threshold, invert_image(image), image)


-solarize_image_pil = _register_kernel_internal(solarize, PIL.Image.Image)(_FP.solarize)
+_solarize_image_pil = _register_kernel_internal(solarize, PIL.Image.Image)(_FP.solarize)


 @_register_kernel_internal(solarize, datapoints.Video)
 def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
-    return solarize_image_tensor(video, threshold=threshold)
+    return solarize_image(video, threshold=threshold)


 def autocontrast(inpt: torch.Tensor) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return autocontrast_image_tensor(inpt)
+        return autocontrast_image(inpt)

     _log_api_usage_once(autocontrast)
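Solarize inverts every pixel at or above the threshold and leaves the rest untouched; for uint8 the inversion is `255 - x`. A numeric sketch:

```python
import torch
import torchvision.transforms.v2.functional as F

img = torch.tensor([[[10, 128, 250]]], dtype=torch.uint8)
# 128 and 250 are >= 128 and invert to 127 and 5; 10 passes through.
print(F.solarize_image(img, threshold=128))  # tensor([[[ 10, 127,   5]]], dtype=torch.uint8)
```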
...
...
@@ -481,7 +481,7 @@ def autocontrast(inpt: torch.Tensor) -> torch.Tensor:

 @_register_kernel_internal(autocontrast, torch.Tensor)
 @_register_kernel_internal(autocontrast, datapoints.Image)
-def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
+def autocontrast_image(image: torch.Tensor) -> torch.Tensor:
     c = image.shape[-3]
     if c not in [1, 3]:
         raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
...
...
@@ -510,17 +510,17 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
     return diff.div_(inv_scale).clamp_(0, bound).to(image.dtype)


-autocontrast_image_pil = _register_kernel_internal(autocontrast, PIL.Image.Image)(_FP.autocontrast)
+_autocontrast_image_pil = _register_kernel_internal(autocontrast, PIL.Image.Image)(_FP.autocontrast)


 @_register_kernel_internal(autocontrast, datapoints.Video)
 def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
-    return autocontrast_image_tensor(video)
+    return autocontrast_image(video)


 def equalize(inpt: torch.Tensor) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return equalize_image_tensor(inpt)
+        return equalize_image(inpt)

     _log_api_usage_once(equalize)
...
...
@@ -530,7 +530,7 @@ def equalize(inpt: torch.Tensor) -> torch.Tensor:

 @_register_kernel_internal(equalize, torch.Tensor)
 @_register_kernel_internal(equalize, datapoints.Image)
-def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
+def equalize_image(image: torch.Tensor) -> torch.Tensor:
     if image.numel() == 0:
         return image
...
...
@@ -545,7 +545,7 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     # Since we need to convert in most cases anyway and out of the acceptable dtypes mentioned in 1. `torch.uint8` is
     # by far the most common, we choose it as base.
     output_dtype = image.dtype
-    image = to_dtype_image_tensor(image, torch.uint8, scale=True)
+    image = to_dtype_image(image, torch.uint8, scale=True)

     # The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image
     # corresponds to adding 1 to index 127 in the histogram.
...
...
@@ -596,20 +596,20 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     equalized_image = lut.gather(dim=-1, index=flat_image).view_as(image)

     output = torch.where(valid_equalization, equalized_image, image)
-    return to_dtype_image_tensor(output, output_dtype, scale=True)
+    return to_dtype_image(output, output_dtype, scale=True)


-equalize_image_pil = _register_kernel_internal(equalize, PIL.Image.Image)(_FP.equalize)
+_equalize_image_pil = _register_kernel_internal(equalize, PIL.Image.Image)(_FP.equalize)


 @_register_kernel_internal(equalize, datapoints.Video)
 def equalize_video(video: torch.Tensor) -> torch.Tensor:
-    return equalize_image_tensor(video)
+    return equalize_image(video)


 def invert(inpt: torch.Tensor) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return invert_image_tensor(inpt)
+        return invert_image(inpt)

     _log_api_usage_once(invert)
...
...
@@ -619,7 +619,7 @@ def invert(inpt: torch.Tensor) -> torch.Tensor:

 @_register_kernel_internal(invert, torch.Tensor)
 @_register_kernel_internal(invert, datapoints.Image)
-def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
+def invert_image(image: torch.Tensor) -> torch.Tensor:
     if image.is_floating_point():
         return 1.0 - image
     elif image.dtype == torch.uint8:
...
...
@@ -629,12 +629,12 @@ def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
         return image.bitwise_xor((1 << _num_value_bits(image.dtype)) - 1)


-invert_image_pil = _register_kernel_internal(invert, PIL.Image.Image)(_FP.invert)
+_invert_image_pil = _register_kernel_internal(invert, PIL.Image.Image)(_FP.invert)


 @_register_kernel_internal(invert, datapoints.Video)
 def invert_video(video: torch.Tensor) -> torch.Tensor:
-    return invert_image_tensor(video)
+    return invert_image(video)


 def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor:
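Inversion is dtype-dependent: float images invert as `1.0 - x`, uint8 effectively as `255 - x`, and other integer dtypes via the XOR shown above. A numeric sketch:

```python
import torch
import torchvision.transforms.v2.functional as F

print(F.invert_image(torch.tensor([[[0.25]]])))                   # tensor([[[0.7500]]])
print(F.invert_image(torch.tensor([[[40]]], dtype=torch.uint8)))  # tensor([[[215]]], dtype=torch.uint8)
```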
...
...
@@ -660,7 +660,7 @@ def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor:
         ValueError: If ``len(permutation)`` doesn't match the number of channels in the input.
     """
     if torch.jit.is_scripting():
-        return permute_channels_image_tensor(inpt, permutation=permutation)
+        return permute_channels_image(inpt, permutation=permutation)

     _log_api_usage_once(permute_channels)
...
...
@@ -670,7 +670,7 @@ def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor:

 @_register_kernel_internal(permute_channels, torch.Tensor)
 @_register_kernel_internal(permute_channels, datapoints.Image)
-def permute_channels_image_tensor(image: torch.Tensor, permutation: List[int]) -> torch.Tensor:
+def permute_channels_image(image: torch.Tensor, permutation: List[int]) -> torch.Tensor:
     shape = image.shape
     num_channels, height, width = shape[-3:]
...
...
@@ -688,10 +688,10 @@ def permute_channels_image_tensor(image: torch.Tensor, permutation: List[int]) -> torch.Tensor:

 @_register_kernel_internal(permute_channels, PIL.Image.Image)
-def permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int]) -> PIL.Image:
-    return to_image_pil(permute_channels_image_tensor(pil_to_tensor(image), permutation=permutation))
+def _permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int]) -> PIL.Image:
+    return to_pil_image(permute_channels_image(pil_to_tensor(image), permutation=permutation))


 @_register_kernel_internal(permute_channels, datapoints.Video)
 def permute_channels_video(video: torch.Tensor, permutation: List[int]) -> torch.Tensor:
-    return permute_channels_image_tensor(video, permutation=permutation)
+    return permute_channels_image(video, permutation=permutation)
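The private PIL kernel for `permute_channels` round-trips through the tensor kernel, which indexes the channel dimension with the given permutation. A usage sketch of the public dispatcher (assuming torchvision at this commit):

```python
import torch
import torchvision.transforms.v2.functional as F

# Three constant channels: 0.0, 0.5, 1.0.
img = torch.stack([torch.full((2, 2), v) for v in (0.0, 0.5, 1.0)])
out = F.permute_channels(img, permutation=[2, 0, 1])
print(out[:, 0, 0])  # tensor([1.0000, 0.0000, 0.5000])
```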