Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
ca012d39
"src/vscode:/vscode.git/clone" did not exist on "048d9019931ddf2b28c4a909bf6c2d681a809159"
Unverified
Commit
ca012d39
authored
Aug 16, 2023
by
Philip Meier
Committed by
GitHub
Aug 16, 2023
Browse files
make PIL kernels private (#7831)
parent
cdbbd666
Changes
25
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
329 additions
and
355 deletions
+329
-355
docs/source/transforms.rst
docs/source/transforms.rst
+1
-2
gallery/plot_transforms_v2_e2e.py
gallery/plot_transforms_v2_e2e.py
+2
-2
references/detection/presets.py
references/detection/presets.py
+4
-4
references/segmentation/presets.py
references/segmentation/presets.py
+3
-3
test/common_utils.py
test/common_utils.py
+5
-5
test/test_prototype_transforms.py
test/test_prototype_transforms.py
+2
-2
test/test_transforms_v2.py
test/test_transforms_v2.py
+6
-23
test/test_transforms_v2_consistency.py
test/test_transforms_v2_consistency.py
+8
-8
test/test_transforms_v2_functional.py
test/test_transforms_v2_functional.py
+14
-14
test/test_transforms_v2_refactored.py
test/test_transforms_v2_refactored.py
+54
-56
test/test_transforms_v2_utils.py
test/test_transforms_v2_utils.py
+2
-2
test/transforms_v2_dispatcher_infos.py
test/transforms_v2_dispatcher_infos.py
+41
-41
test/transforms_v2_kernel_infos.py
test/transforms_v2_kernel_infos.py
+45
-46
torchvision/prototype/transforms/_augment.py
torchvision/prototype/transforms/_augment.py
+2
-2
torchvision/transforms/v2/__init__.py
torchvision/transforms/v2/__init__.py
+1
-1
torchvision/transforms/v2/_auto_augment.py
torchvision/transforms/v2/_auto_augment.py
+1
-1
torchvision/transforms/v2/_type_conversion.py
torchvision/transforms/v2/_type_conversion.py
+4
-9
torchvision/transforms/v2/functional/__init__.py
torchvision/transforms/v2/functional/__init__.py
+64
-64
torchvision/transforms/v2/functional/_augment.py
torchvision/transforms/v2/functional/_augment.py
+5
-5
torchvision/transforms/v2/functional/_color.py
torchvision/transforms/v2/functional/_color.py
+65
-65
No files found.
docs/source/transforms.rst
View file @
ca012d39
...
@@ -228,12 +228,11 @@ Conversion
...
@@ -228,12 +228,11 @@ Conversion
ToPILImage
ToPILImage
v2.ToPILImage
v2.ToPILImage
v2.ToImagePIL
ToTensor
ToTensor
v2.ToTensor
v2.ToTensor
PILToTensor
PILToTensor
v2.PILToTensor
v2.PILToTensor
v2.ToImage
Tensor
v2.ToImage
ConvertImageDtype
ConvertImageDtype
v2.ConvertImageDtype
v2.ConvertImageDtype
v2.ToDtype
v2.ToDtype
...
...
gallery/plot_transforms_v2_e2e.py
View file @
ca012d39
...
@@ -27,7 +27,7 @@ def show(sample):
...
@@ -27,7 +27,7 @@ def show(sample):
image
,
target
=
sample
image
,
target
=
sample
if
isinstance
(
image
,
PIL
.
Image
.
Image
):
if
isinstance
(
image
,
PIL
.
Image
.
Image
):
image
=
F
.
to_image
_tensor
(
image
)
image
=
F
.
to_image
(
image
)
image
=
F
.
to_dtype
(
image
,
torch
.
uint8
,
scale
=
True
)
image
=
F
.
to_dtype
(
image
,
torch
.
uint8
,
scale
=
True
)
annotated_image
=
draw_bounding_boxes
(
image
,
target
[
"boxes"
],
colors
=
"yellow"
,
width
=
3
)
annotated_image
=
draw_bounding_boxes
(
image
,
target
[
"boxes"
],
colors
=
"yellow"
,
width
=
3
)
...
@@ -101,7 +101,7 @@ transform = transforms.Compose(
...
@@ -101,7 +101,7 @@ transform = transforms.Compose(
transforms
.
RandomZoomOut
(
fill
=
{
PIL
.
Image
.
Image
:
(
123
,
117
,
104
),
"others"
:
0
}),
transforms
.
RandomZoomOut
(
fill
=
{
PIL
.
Image
.
Image
:
(
123
,
117
,
104
),
"others"
:
0
}),
transforms
.
RandomIoUCrop
(),
transforms
.
RandomIoUCrop
(),
transforms
.
RandomHorizontalFlip
(),
transforms
.
RandomHorizontalFlip
(),
transforms
.
ToImage
Tensor
(),
transforms
.
ToImage
(),
transforms
.
ConvertImageDtype
(
torch
.
float32
),
transforms
.
ConvertImageDtype
(
torch
.
float32
),
transforms
.
SanitizeBoundingBoxes
(),
transforms
.
SanitizeBoundingBoxes
(),
]
]
...
...
references/detection/presets.py
View file @
ca012d39
...
@@ -33,7 +33,7 @@ class DetectionPresetTrain:
...
@@ -33,7 +33,7 @@ class DetectionPresetTrain:
transforms
=
[]
transforms
=
[]
backend
=
backend
.
lower
()
backend
=
backend
.
lower
()
if
backend
==
"datapoint"
:
if
backend
==
"datapoint"
:
transforms
.
append
(
T
.
ToImage
Tensor
())
transforms
.
append
(
T
.
ToImage
())
elif
backend
==
"tensor"
:
elif
backend
==
"tensor"
:
transforms
.
append
(
T
.
PILToTensor
())
transforms
.
append
(
T
.
PILToTensor
())
elif
backend
!=
"pil"
:
elif
backend
!=
"pil"
:
...
@@ -71,7 +71,7 @@ class DetectionPresetTrain:
...
@@ -71,7 +71,7 @@ class DetectionPresetTrain:
if
backend
==
"pil"
:
if
backend
==
"pil"
:
# Note: we could just convert to pure tensors even in v2.
# Note: we could just convert to pure tensors even in v2.
transforms
+=
[
T
.
ToImage
Tensor
()
if
use_v2
else
T
.
PILToTensor
()]
transforms
+=
[
T
.
ToImage
()
if
use_v2
else
T
.
PILToTensor
()]
transforms
+=
[
T
.
ConvertImageDtype
(
torch
.
float
)]
transforms
+=
[
T
.
ConvertImageDtype
(
torch
.
float
)]
...
@@ -94,11 +94,11 @@ class DetectionPresetEval:
...
@@ -94,11 +94,11 @@ class DetectionPresetEval:
backend
=
backend
.
lower
()
backend
=
backend
.
lower
()
if
backend
==
"pil"
:
if
backend
==
"pil"
:
# Note: we could just convert to pure tensors even in v2?
# Note: we could just convert to pure tensors even in v2?
transforms
+=
[
T
.
ToImage
Tensor
()
if
use_v2
else
T
.
PILToTensor
()]
transforms
+=
[
T
.
ToImage
()
if
use_v2
else
T
.
PILToTensor
()]
elif
backend
==
"tensor"
:
elif
backend
==
"tensor"
:
transforms
+=
[
T
.
PILToTensor
()]
transforms
+=
[
T
.
PILToTensor
()]
elif
backend
==
"datapoint"
:
elif
backend
==
"datapoint"
:
transforms
+=
[
T
.
ToImage
Tensor
()]
transforms
+=
[
T
.
ToImage
()]
else
:
else
:
raise
ValueError
(
f
"backend can be 'datapoint', 'tensor' or 'pil', but got
{
backend
}
"
)
raise
ValueError
(
f
"backend can be 'datapoint', 'tensor' or 'pil', but got
{
backend
}
"
)
...
...
references/segmentation/presets.py
View file @
ca012d39
...
@@ -32,7 +32,7 @@ class SegmentationPresetTrain:
...
@@ -32,7 +32,7 @@ class SegmentationPresetTrain:
transforms
=
[]
transforms
=
[]
backend
=
backend
.
lower
()
backend
=
backend
.
lower
()
if
backend
==
"datapoint"
:
if
backend
==
"datapoint"
:
transforms
.
append
(
T
.
ToImage
Tensor
())
transforms
.
append
(
T
.
ToImage
())
elif
backend
==
"tensor"
:
elif
backend
==
"tensor"
:
transforms
.
append
(
T
.
PILToTensor
())
transforms
.
append
(
T
.
PILToTensor
())
elif
backend
!=
"pil"
:
elif
backend
!=
"pil"
:
...
@@ -81,7 +81,7 @@ class SegmentationPresetEval:
...
@@ -81,7 +81,7 @@ class SegmentationPresetEval:
if
backend
==
"tensor"
:
if
backend
==
"tensor"
:
transforms
+=
[
T
.
PILToTensor
()]
transforms
+=
[
T
.
PILToTensor
()]
elif
backend
==
"datapoint"
:
elif
backend
==
"datapoint"
:
transforms
+=
[
T
.
ToImage
Tensor
()]
transforms
+=
[
T
.
ToImage
()]
elif
backend
!=
"pil"
:
elif
backend
!=
"pil"
:
raise
ValueError
(
f
"backend can be 'datapoint', 'tensor' or 'pil', but got
{
backend
}
"
)
raise
ValueError
(
f
"backend can be 'datapoint', 'tensor' or 'pil', but got
{
backend
}
"
)
...
@@ -92,7 +92,7 @@ class SegmentationPresetEval:
...
@@ -92,7 +92,7 @@ class SegmentationPresetEval:
if
backend
==
"pil"
:
if
backend
==
"pil"
:
# Note: we could just convert to pure tensors even in v2?
# Note: we could just convert to pure tensors even in v2?
transforms
+=
[
T
.
ToImage
Tensor
()
if
use_v2
else
T
.
PILToTensor
()]
transforms
+=
[
T
.
ToImage
()
if
use_v2
else
T
.
PILToTensor
()]
transforms
+=
[
transforms
+=
[
T
.
ConvertImageDtype
(
torch
.
float
),
T
.
ConvertImageDtype
(
torch
.
float
),
...
...
test/common_utils.py
View file @
ca012d39
...
@@ -27,7 +27,7 @@ from PIL import Image
...
@@ -27,7 +27,7 @@ from PIL import Image
from
torch.testing._comparison
import
BooleanPair
,
NonePair
,
not_close_error_metas
,
NumberPair
,
TensorLikePair
from
torch.testing._comparison
import
BooleanPair
,
NonePair
,
not_close_error_metas
,
NumberPair
,
TensorLikePair
from
torchvision
import
datapoints
,
io
from
torchvision
import
datapoints
,
io
from
torchvision.transforms._functional_tensor
import
_max_value
as
get_max_value
from
torchvision.transforms._functional_tensor
import
_max_value
as
get_max_value
from
torchvision.transforms.v2.functional
import
to_dtype_image
_tensor
,
to_image
_pil
,
to_image
_tensor
from
torchvision.transforms.v2.functional
import
to_dtype_image
,
to_image
,
to_
pil_
image
IN_OSS_CI
=
any
(
os
.
getenv
(
var
)
==
"true"
for
var
in
[
"CIRCLECI"
,
"GITHUB_ACTIONS"
])
IN_OSS_CI
=
any
(
os
.
getenv
(
var
)
==
"true"
for
var
in
[
"CIRCLECI"
,
"GITHUB_ACTIONS"
])
...
@@ -293,7 +293,7 @@ class ImagePair(TensorLikePair):
...
@@ -293,7 +293,7 @@ class ImagePair(TensorLikePair):
**
other_parameters
,
**
other_parameters
,
):
):
if
all
(
isinstance
(
input
,
PIL
.
Image
.
Image
)
for
input
in
[
actual
,
expected
]):
if
all
(
isinstance
(
input
,
PIL
.
Image
.
Image
)
for
input
in
[
actual
,
expected
]):
actual
,
expected
=
[
to_image
_tensor
(
input
)
for
input
in
[
actual
,
expected
]]
actual
,
expected
=
[
to_image
(
input
)
for
input
in
[
actual
,
expected
]]
super
().
__init__
(
actual
,
expected
,
**
other_parameters
)
super
().
__init__
(
actual
,
expected
,
**
other_parameters
)
self
.
mae
=
mae
self
.
mae
=
mae
...
@@ -536,7 +536,7 @@ def make_image_tensor(*args, **kwargs):
...
@@ -536,7 +536,7 @@ def make_image_tensor(*args, **kwargs):
def
make_image_pil
(
*
args
,
**
kwargs
):
def
make_image_pil
(
*
args
,
**
kwargs
):
return
to_image
_pil
(
make_image
(
*
args
,
**
kwargs
))
return
to_
pil_
image
(
make_image
(
*
args
,
**
kwargs
))
def
make_image_loader
(
def
make_image_loader
(
...
@@ -609,12 +609,12 @@ def make_image_loader_for_interpolation(
...
@@ -609,12 +609,12 @@ def make_image_loader_for_interpolation(
)
)
)
)
image_tensor
=
to_image
_tensor
(
image_pil
)
image_tensor
=
to_image
(
image_pil
)
if
memory_format
==
torch
.
contiguous_format
:
if
memory_format
==
torch
.
contiguous_format
:
image_tensor
=
image_tensor
.
to
(
device
=
device
,
memory_format
=
memory_format
,
copy
=
True
)
image_tensor
=
image_tensor
.
to
(
device
=
device
,
memory_format
=
memory_format
,
copy
=
True
)
else
:
else
:
image_tensor
=
image_tensor
.
to
(
device
=
device
)
image_tensor
=
image_tensor
.
to
(
device
=
device
)
image_tensor
=
to_dtype_image
_tensor
(
image_tensor
,
dtype
=
dtype
,
scale
=
True
)
image_tensor
=
to_dtype_image
(
image_tensor
,
dtype
=
dtype
,
scale
=
True
)
return
datapoints
.
Image
(
image_tensor
)
return
datapoints
.
Image
(
image_tensor
)
...
...
test/test_prototype_transforms.py
View file @
ca012d39
...
@@ -17,7 +17,7 @@ from prototype_common_utils import make_label
...
@@ -17,7 +17,7 @@ from prototype_common_utils import make_label
from
torchvision.datapoints
import
BoundingBoxes
,
BoundingBoxFormat
,
Image
,
Mask
,
Video
from
torchvision.datapoints
import
BoundingBoxes
,
BoundingBoxFormat
,
Image
,
Mask
,
Video
from
torchvision.prototype
import
datapoints
,
transforms
from
torchvision.prototype
import
datapoints
,
transforms
from
torchvision.transforms.v2.functional
import
clamp_bounding_boxes
,
InterpolationMode
,
pil_to_tensor
,
to_image
_pil
from
torchvision.transforms.v2.functional
import
clamp_bounding_boxes
,
InterpolationMode
,
pil_to_tensor
,
to_
pil_
image
from
torchvision.transforms.v2.utils
import
check_type
,
is_simple_tensor
from
torchvision.transforms.v2.utils
import
check_type
,
is_simple_tensor
BATCH_EXTRA_DIMS
=
[
extra_dims
for
extra_dims
in
DEFAULT_EXTRA_DIMS
if
extra_dims
]
BATCH_EXTRA_DIMS
=
[
extra_dims
for
extra_dims
in
DEFAULT_EXTRA_DIMS
if
extra_dims
]
...
@@ -387,7 +387,7 @@ def test_fixed_sized_crop_against_detection_reference():
...
@@ -387,7 +387,7 @@ def test_fixed_sized_crop_against_detection_reference():
size
=
(
600
,
800
)
size
=
(
600
,
800
)
num_objects
=
22
num_objects
=
22
pil_image
=
to_image
_pil
(
make_image
(
size
=
size
,
color_space
=
"RGB"
))
pil_image
=
to_
pil_
image
(
make_image
(
size
=
size
,
color_space
=
"RGB"
))
target
=
{
target
=
{
"boxes"
:
make_bounding_box
(
canvas_size
=
size
,
format
=
"XYXY"
,
batch_dims
=
(
num_objects
,),
dtype
=
torch
.
float
),
"boxes"
:
make_bounding_box
(
canvas_size
=
size
,
format
=
"XYXY"
,
batch_dims
=
(
num_objects
,),
dtype
=
torch
.
float
),
"labels"
:
make_label
(
extra_dims
=
(
num_objects
,),
categories
=
80
),
"labels"
:
make_label
(
extra_dims
=
(
num_objects
,),
categories
=
80
),
...
...
test/test_transforms_v2.py
View file @
ca012d39
...
@@ -666,19 +666,19 @@ class TestTransform:
...
@@ -666,19 +666,19 @@ class TestTransform:
t
(
inpt
)
t
(
inpt
)
class
TestToImage
Tensor
:
class
TestToImage
:
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"inpt_type"
,
"inpt_type"
,
[
torch
.
Tensor
,
PIL
.
Image
.
Image
,
datapoints
.
Image
,
np
.
ndarray
,
datapoints
.
BoundingBoxes
,
str
,
int
],
[
torch
.
Tensor
,
PIL
.
Image
.
Image
,
datapoints
.
Image
,
np
.
ndarray
,
datapoints
.
BoundingBoxes
,
str
,
int
],
)
)
def
test__transform
(
self
,
inpt_type
,
mocker
):
def
test__transform
(
self
,
inpt_type
,
mocker
):
fn
=
mocker
.
patch
(
fn
=
mocker
.
patch
(
"torchvision.transforms.v2.functional.to_image
_tensor
"
,
"torchvision.transforms.v2.functional.to_image"
,
return_value
=
torch
.
rand
(
1
,
3
,
8
,
8
),
return_value
=
torch
.
rand
(
1
,
3
,
8
,
8
),
)
)
inpt
=
mocker
.
MagicMock
(
spec
=
inpt_type
)
inpt
=
mocker
.
MagicMock
(
spec
=
inpt_type
)
transform
=
transforms
.
ToImage
Tensor
()
transform
=
transforms
.
ToImage
()
transform
(
inpt
)
transform
(
inpt
)
if
inpt_type
in
(
datapoints
.
BoundingBoxes
,
datapoints
.
Image
,
str
,
int
):
if
inpt_type
in
(
datapoints
.
BoundingBoxes
,
datapoints
.
Image
,
str
,
int
):
assert
fn
.
call_count
==
0
assert
fn
.
call_count
==
0
...
@@ -686,30 +686,13 @@ class TestToImageTensor:
...
@@ -686,30 +686,13 @@ class TestToImageTensor:
fn
.
assert_called_once_with
(
inpt
)
fn
.
assert_called_once_with
(
inpt
)
class
TestToImagePIL
:
@
pytest
.
mark
.
parametrize
(
"inpt_type"
,
[
torch
.
Tensor
,
PIL
.
Image
.
Image
,
datapoints
.
Image
,
np
.
ndarray
,
datapoints
.
BoundingBoxes
,
str
,
int
],
)
def
test__transform
(
self
,
inpt_type
,
mocker
):
fn
=
mocker
.
patch
(
"torchvision.transforms.v2.functional.to_image_pil"
)
inpt
=
mocker
.
MagicMock
(
spec
=
inpt_type
)
transform
=
transforms
.
ToImagePIL
()
transform
(
inpt
)
if
inpt_type
in
(
datapoints
.
BoundingBoxes
,
PIL
.
Image
.
Image
,
str
,
int
):
assert
fn
.
call_count
==
0
else
:
fn
.
assert_called_once_with
(
inpt
,
mode
=
transform
.
mode
)
class
TestToPILImage
:
class
TestToPILImage
:
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"inpt_type"
,
"inpt_type"
,
[
torch
.
Tensor
,
PIL
.
Image
.
Image
,
datapoints
.
Image
,
np
.
ndarray
,
datapoints
.
BoundingBoxes
,
str
,
int
],
[
torch
.
Tensor
,
PIL
.
Image
.
Image
,
datapoints
.
Image
,
np
.
ndarray
,
datapoints
.
BoundingBoxes
,
str
,
int
],
)
)
def
test__transform
(
self
,
inpt_type
,
mocker
):
def
test__transform
(
self
,
inpt_type
,
mocker
):
fn
=
mocker
.
patch
(
"torchvision.transforms.v2.functional.to_image
_pil
"
)
fn
=
mocker
.
patch
(
"torchvision.transforms.v2.functional.to_
pil_
image"
)
inpt
=
mocker
.
MagicMock
(
spec
=
inpt_type
)
inpt
=
mocker
.
MagicMock
(
spec
=
inpt_type
)
transform
=
transforms
.
ToPILImage
()
transform
=
transforms
.
ToPILImage
()
...
@@ -1013,7 +996,7 @@ def test_antialias_warning():
...
@@ -1013,7 +996,7 @@ def test_antialias_warning():
@
pytest
.
mark
.
parametrize
(
"image_type"
,
(
PIL
.
Image
,
torch
.
Tensor
,
datapoints
.
Image
))
@
pytest
.
mark
.
parametrize
(
"image_type"
,
(
PIL
.
Image
,
torch
.
Tensor
,
datapoints
.
Image
))
@
pytest
.
mark
.
parametrize
(
"label_type"
,
(
torch
.
Tensor
,
int
))
@
pytest
.
mark
.
parametrize
(
"label_type"
,
(
torch
.
Tensor
,
int
))
@
pytest
.
mark
.
parametrize
(
"dataset_return_type"
,
(
dict
,
tuple
))
@
pytest
.
mark
.
parametrize
(
"dataset_return_type"
,
(
dict
,
tuple
))
@
pytest
.
mark
.
parametrize
(
"to_tensor"
,
(
transforms
.
ToTensor
,
transforms
.
ToImage
Tensor
))
@
pytest
.
mark
.
parametrize
(
"to_tensor"
,
(
transforms
.
ToTensor
,
transforms
.
ToImage
))
def
test_classif_preset
(
image_type
,
label_type
,
dataset_return_type
,
to_tensor
):
def
test_classif_preset
(
image_type
,
label_type
,
dataset_return_type
,
to_tensor
):
image
=
datapoints
.
Image
(
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
250
,
250
),
dtype
=
torch
.
uint8
))
image
=
datapoints
.
Image
(
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
250
,
250
),
dtype
=
torch
.
uint8
))
...
@@ -1074,7 +1057,7 @@ def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
...
@@ -1074,7 +1057,7 @@ def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
@
pytest
.
mark
.
parametrize
(
"image_type"
,
(
PIL
.
Image
,
torch
.
Tensor
,
datapoints
.
Image
))
@
pytest
.
mark
.
parametrize
(
"image_type"
,
(
PIL
.
Image
,
torch
.
Tensor
,
datapoints
.
Image
))
@
pytest
.
mark
.
parametrize
(
"data_augmentation"
,
(
"hflip"
,
"lsj"
,
"multiscale"
,
"ssd"
,
"ssdlite"
))
@
pytest
.
mark
.
parametrize
(
"data_augmentation"
,
(
"hflip"
,
"lsj"
,
"multiscale"
,
"ssd"
,
"ssdlite"
))
@
pytest
.
mark
.
parametrize
(
"to_tensor"
,
(
transforms
.
ToTensor
,
transforms
.
ToImage
Tensor
))
@
pytest
.
mark
.
parametrize
(
"to_tensor"
,
(
transforms
.
ToTensor
,
transforms
.
ToImage
))
@
pytest
.
mark
.
parametrize
(
"sanitize"
,
(
True
,
False
))
@
pytest
.
mark
.
parametrize
(
"sanitize"
,
(
True
,
False
))
def
test_detection_preset
(
image_type
,
data_augmentation
,
to_tensor
,
sanitize
):
def
test_detection_preset
(
image_type
,
data_augmentation
,
to_tensor
,
sanitize
):
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
...
...
test/test_transforms_v2_consistency.py
View file @
ca012d39
...
@@ -30,7 +30,7 @@ from torchvision._utils import sequence_to_str
...
@@ -30,7 +30,7 @@ from torchvision._utils import sequence_to_str
from
torchvision.transforms
import
functional
as
legacy_F
from
torchvision.transforms
import
functional
as
legacy_F
from
torchvision.transforms.v2
import
functional
as
prototype_F
from
torchvision.transforms.v2
import
functional
as
prototype_F
from
torchvision.transforms.v2._utils
import
_get_fill
from
torchvision.transforms.v2._utils
import
_get_fill
from
torchvision.transforms.v2.functional
import
to_image
_pil
from
torchvision.transforms.v2.functional
import
to_
pil_
image
from
torchvision.transforms.v2.utils
import
query_size
from
torchvision.transforms.v2.utils
import
query_size
DEFAULT_MAKE_IMAGES_KWARGS
=
dict
(
color_spaces
=
[
"RGB"
],
extra_dims
=
[(
4
,)])
DEFAULT_MAKE_IMAGES_KWARGS
=
dict
(
color_spaces
=
[
"RGB"
],
extra_dims
=
[(
4
,)])
...
@@ -630,7 +630,7 @@ def check_call_consistency(
...
@@ -630,7 +630,7 @@ def check_call_consistency(
)
)
if
image
.
ndim
==
3
and
supports_pil
:
if
image
.
ndim
==
3
and
supports_pil
:
image_pil
=
to_image
_pil
(
image
)
image_pil
=
to_
pil_
image
(
image
)
try
:
try
:
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
...
@@ -869,7 +869,7 @@ class TestToTensorTransforms:
...
@@ -869,7 +869,7 @@ class TestToTensorTransforms:
legacy_transform
=
legacy_transforms
.
PILToTensor
()
legacy_transform
=
legacy_transforms
.
PILToTensor
()
for
image
in
make_images
(
extra_dims
=
[()]):
for
image
in
make_images
(
extra_dims
=
[()]):
image_pil
=
to_image
_pil
(
image
)
image_pil
=
to_
pil_
image
(
image
)
assert_equal
(
prototype_transform
(
image_pil
),
legacy_transform
(
image_pil
))
assert_equal
(
prototype_transform
(
image_pil
),
legacy_transform
(
image_pil
))
...
@@ -879,7 +879,7 @@ class TestToTensorTransforms:
...
@@ -879,7 +879,7 @@ class TestToTensorTransforms:
legacy_transform
=
legacy_transforms
.
ToTensor
()
legacy_transform
=
legacy_transforms
.
ToTensor
()
for
image
in
make_images
(
extra_dims
=
[()]):
for
image
in
make_images
(
extra_dims
=
[()]):
image_pil
=
to_image
_pil
(
image
)
image_pil
=
to_
pil_
image
(
image
)
image_numpy
=
np
.
array
(
image_pil
)
image_numpy
=
np
.
array
(
image_pil
)
assert_equal
(
prototype_transform
(
image_pil
),
legacy_transform
(
image_pil
))
assert_equal
(
prototype_transform
(
image_pil
),
legacy_transform
(
image_pil
))
...
@@ -1088,7 +1088,7 @@ class TestRefDetTransforms:
...
@@ -1088,7 +1088,7 @@ class TestRefDetTransforms:
def
make_label
(
extra_dims
,
categories
):
def
make_label
(
extra_dims
,
categories
):
return
torch
.
randint
(
categories
,
extra_dims
,
dtype
=
torch
.
int64
)
return
torch
.
randint
(
categories
,
extra_dims
,
dtype
=
torch
.
int64
)
pil_image
=
to_image
_pil
(
make_image
(
size
=
size
,
color_space
=
"RGB"
))
pil_image
=
to_
pil_
image
(
make_image
(
size
=
size
,
color_space
=
"RGB"
))
target
=
{
target
=
{
"boxes"
:
make_bounding_box
(
canvas_size
=
size
,
format
=
"XYXY"
,
batch_dims
=
(
num_objects
,),
dtype
=
torch
.
float
),
"boxes"
:
make_bounding_box
(
canvas_size
=
size
,
format
=
"XYXY"
,
batch_dims
=
(
num_objects
,),
dtype
=
torch
.
float
),
"labels"
:
make_label
(
extra_dims
=
(
num_objects
,),
categories
=
80
),
"labels"
:
make_label
(
extra_dims
=
(
num_objects
,),
categories
=
80
),
...
@@ -1192,7 +1192,7 @@ class TestRefSegTransforms:
...
@@ -1192,7 +1192,7 @@ class TestRefSegTransforms:
conv_fns
=
[]
conv_fns
=
[]
if
supports_pil
:
if
supports_pil
:
conv_fns
.
append
(
to_image
_pil
)
conv_fns
.
append
(
to_
pil_
image
)
conv_fns
.
extend
([
torch
.
Tensor
,
lambda
x
:
x
])
conv_fns
.
extend
([
torch
.
Tensor
,
lambda
x
:
x
])
for
conv_fn
in
conv_fns
:
for
conv_fn
in
conv_fns
:
...
@@ -1201,8 +1201,8 @@ class TestRefSegTransforms:
...
@@ -1201,8 +1201,8 @@ class TestRefSegTransforms:
dp
=
(
conv_fn
(
datapoint_image
),
datapoint_mask
)
dp
=
(
conv_fn
(
datapoint_image
),
datapoint_mask
)
dp_ref
=
(
dp_ref
=
(
to_image
_pil
(
datapoint_image
)
if
supports_pil
else
datapoint_image
.
as_subclass
(
torch
.
Tensor
),
to_
pil_
image
(
datapoint_image
)
if
supports_pil
else
datapoint_image
.
as_subclass
(
torch
.
Tensor
),
to_image
_pil
(
datapoint_mask
),
to_
pil_
image
(
datapoint_mask
),
)
)
yield
dp
,
dp_ref
yield
dp
,
dp_ref
...
...
test/test_transforms_v2_functional.py
View file @
ca012d39
...
@@ -280,12 +280,12 @@ class TestKernels:
...
@@ -280,12 +280,12 @@ class TestKernels:
adapted_other_args
,
adapted_kwargs
=
info
.
float32_vs_uint8
(
other_args
,
kwargs
)
adapted_other_args
,
adapted_kwargs
=
info
.
float32_vs_uint8
(
other_args
,
kwargs
)
actual
=
info
.
kernel
(
actual
=
info
.
kernel
(
F
.
to_dtype_image
_tensor
(
input
,
dtype
=
torch
.
float32
,
scale
=
True
),
F
.
to_dtype_image
(
input
,
dtype
=
torch
.
float32
,
scale
=
True
),
*
adapted_other_args
,
*
adapted_other_args
,
**
adapted_kwargs
,
**
adapted_kwargs
,
)
)
expected
=
F
.
to_dtype_image
_tensor
(
info
.
kernel
(
input
,
*
other_args
,
**
kwargs
),
dtype
=
torch
.
float32
,
scale
=
True
)
expected
=
F
.
to_dtype_image
(
info
.
kernel
(
input
,
*
other_args
,
**
kwargs
),
dtype
=
torch
.
float32
,
scale
=
True
)
assert_close
(
assert_close
(
actual
,
actual
,
...
@@ -377,7 +377,7 @@ class TestDispatchers:
...
@@ -377,7 +377,7 @@ class TestDispatchers:
if
image_datapoint
.
ndim
>
3
:
if
image_datapoint
.
ndim
>
3
:
pytest
.
skip
(
"Input is batched"
)
pytest
.
skip
(
"Input is batched"
)
image_pil
=
F
.
to_image
_pil
(
image_datapoint
)
image_pil
=
F
.
to_
pil_
image
(
image_datapoint
)
output
=
info
.
dispatcher
(
image_pil
,
*
other_args
,
**
kwargs
)
output
=
info
.
dispatcher
(
image_pil
,
*
other_args
,
**
kwargs
)
...
@@ -470,7 +470,7 @@ class TestDispatchers:
...
@@ -470,7 +470,7 @@ class TestDispatchers:
(
F
.
hflip
,
F
.
horizontal_flip
),
(
F
.
hflip
,
F
.
horizontal_flip
),
(
F
.
vflip
,
F
.
vertical_flip
),
(
F
.
vflip
,
F
.
vertical_flip
),
(
F
.
get_image_num_channels
,
F
.
get_num_channels
),
(
F
.
get_image_num_channels
,
F
.
get_num_channels
),
(
F
.
to_pil_image
,
F
.
to_image
_pil
),
(
F
.
to_pil_image
,
F
.
to_
pil_
image
),
(
F
.
elastic_transform
,
F
.
elastic
),
(
F
.
elastic_transform
,
F
.
elastic
),
(
F
.
to_grayscale
,
F
.
rgb_to_grayscale
),
(
F
.
to_grayscale
,
F
.
rgb_to_grayscale
),
]
]
...
@@ -493,7 +493,7 @@ def test_normalize_image_tensor_stats(device, num_channels):
...
@@ -493,7 +493,7 @@ def test_normalize_image_tensor_stats(device, num_channels):
mean
=
image
.
mean
(
dim
=
(
1
,
2
)).
tolist
()
mean
=
image
.
mean
(
dim
=
(
1
,
2
)).
tolist
()
std
=
image
.
std
(
dim
=
(
1
,
2
)).
tolist
()
std
=
image
.
std
(
dim
=
(
1
,
2
)).
tolist
()
assert_samples_from_standard_normal
(
F
.
normalize_image
_tensor
(
image
,
mean
,
std
))
assert_samples_from_standard_normal
(
F
.
normalize_image
(
image
,
mean
,
std
))
class
TestClampBoundingBoxes
:
class
TestClampBoundingBoxes
:
...
@@ -899,7 +899,7 @@ def test_correctness_center_crop_mask(device, output_size):
...
@@ -899,7 +899,7 @@ def test_correctness_center_crop_mask(device, output_size):
_
,
image_height
,
image_width
=
mask
.
shape
_
,
image_height
,
image_width
=
mask
.
shape
if
crop_width
>
image_height
or
crop_height
>
image_width
:
if
crop_width
>
image_height
or
crop_height
>
image_width
:
padding
=
_center_crop_compute_padding
(
crop_height
,
crop_width
,
image_height
,
image_width
)
padding
=
_center_crop_compute_padding
(
crop_height
,
crop_width
,
image_height
,
image_width
)
mask
=
F
.
pad_image
_tensor
(
mask
,
padding
,
fill
=
0
)
mask
=
F
.
pad_image
(
mask
,
padding
,
fill
=
0
)
left
=
round
((
image_width
-
crop_width
)
*
0.5
)
left
=
round
((
image_width
-
crop_width
)
*
0.5
)
top
=
round
((
image_height
-
crop_height
)
*
0.5
)
top
=
round
((
image_height
-
crop_height
)
*
0.5
)
...
@@ -920,7 +920,7 @@ def test_correctness_center_crop_mask(device, output_size):
...
@@ -920,7 +920,7 @@ def test_correctness_center_crop_mask(device, output_size):
@
pytest
.
mark
.
parametrize
(
"ksize"
,
[(
3
,
3
),
[
3
,
5
],
(
23
,
23
)])
@
pytest
.
mark
.
parametrize
(
"ksize"
,
[(
3
,
3
),
[
3
,
5
],
(
23
,
23
)])
@
pytest
.
mark
.
parametrize
(
"sigma"
,
[[
0.5
,
0.5
],
(
0.5
,
0.5
),
(
0.8
,
0.8
),
(
1.7
,
1.7
)])
@
pytest
.
mark
.
parametrize
(
"sigma"
,
[[
0.5
,
0.5
],
(
0.5
,
0.5
),
(
0.8
,
0.8
),
(
1.7
,
1.7
)])
def
test_correctness_gaussian_blur_image_tensor
(
device
,
canvas_size
,
dt
,
ksize
,
sigma
):
def
test_correctness_gaussian_blur_image_tensor
(
device
,
canvas_size
,
dt
,
ksize
,
sigma
):
fn
=
F
.
gaussian_blur_image
_tensor
fn
=
F
.
gaussian_blur_image
# true_cv2_results = {
# true_cv2_results = {
# # np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
# # np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
...
@@ -977,8 +977,8 @@ def test_correctness_gaussian_blur_image_tensor(device, canvas_size, dt, ksize,
...
@@ -977,8 +977,8 @@ def test_correctness_gaussian_blur_image_tensor(device, canvas_size, dt, ksize,
PIL
.
Image
.
new
(
"RGB"
,
(
32
,
32
),
122
),
PIL
.
Image
.
new
(
"RGB"
,
(
32
,
32
),
122
),
],
],
)
)
def
test_to_image
_tensor
(
inpt
):
def
test_to_image
(
inpt
):
output
=
F
.
to_image
_tensor
(
inpt
)
output
=
F
.
to_image
(
inpt
)
assert
isinstance
(
output
,
torch
.
Tensor
)
assert
isinstance
(
output
,
torch
.
Tensor
)
assert
output
.
shape
==
(
3
,
32
,
32
)
assert
output
.
shape
==
(
3
,
32
,
32
)
...
@@ -993,8 +993,8 @@ def test_to_image_tensor(inpt):
...
@@ -993,8 +993,8 @@ def test_to_image_tensor(inpt):
],
],
)
)
@
pytest
.
mark
.
parametrize
(
"mode"
,
[
None
,
"RGB"
])
@
pytest
.
mark
.
parametrize
(
"mode"
,
[
None
,
"RGB"
])
def
test_to_image
_pil
(
inpt
,
mode
):
def
test_to_
pil_
image
(
inpt
,
mode
):
output
=
F
.
to_image
_pil
(
inpt
,
mode
=
mode
)
output
=
F
.
to_
pil_
image
(
inpt
,
mode
=
mode
)
assert
isinstance
(
output
,
PIL
.
Image
.
Image
)
assert
isinstance
(
output
,
PIL
.
Image
.
Image
)
assert
np
.
asarray
(
inpt
).
sum
()
==
np
.
asarray
(
output
).
sum
()
assert
np
.
asarray
(
inpt
).
sum
()
==
np
.
asarray
(
output
).
sum
()
...
@@ -1002,12 +1002,12 @@ def test_to_image_pil(inpt, mode):
...
@@ -1002,12 +1002,12 @@ def test_to_image_pil(inpt, mode):
def
test_equalize_image_tensor_edge_cases
():
def
test_equalize_image_tensor_edge_cases
():
inpt
=
torch
.
zeros
(
3
,
200
,
200
,
dtype
=
torch
.
uint8
)
inpt
=
torch
.
zeros
(
3
,
200
,
200
,
dtype
=
torch
.
uint8
)
output
=
F
.
equalize_image
_tensor
(
inpt
)
output
=
F
.
equalize_image
(
inpt
)
torch
.
testing
.
assert_close
(
inpt
,
output
)
torch
.
testing
.
assert_close
(
inpt
,
output
)
inpt
=
torch
.
zeros
(
5
,
3
,
200
,
200
,
dtype
=
torch
.
uint8
)
inpt
=
torch
.
zeros
(
5
,
3
,
200
,
200
,
dtype
=
torch
.
uint8
)
inpt
[...,
100
:,
100
:]
=
1
inpt
[...,
100
:,
100
:]
=
1
output
=
F
.
equalize_image
_tensor
(
inpt
)
output
=
F
.
equalize_image
(
inpt
)
assert
output
.
unique
().
tolist
()
==
[
0
,
255
]
assert
output
.
unique
().
tolist
()
==
[
0
,
255
]
...
@@ -1024,7 +1024,7 @@ def test_correctness_uniform_temporal_subsample(device):
...
@@ -1024,7 +1024,7 @@ def test_correctness_uniform_temporal_subsample(device):
# TODO: We can remove this test and related torchvision workaround
# TODO: We can remove this test and related torchvision workaround
# once we fixed related pytorch issue: https://github.com/pytorch/pytorch/issues/68430
# once we fixed related pytorch issue: https://github.com/pytorch/pytorch/issues/68430
@
make_info_args_kwargs_parametrization
(
@
make_info_args_kwargs_parametrization
(
[
info
for
info
in
KERNEL_INFOS
if
info
.
kernel
is
F
.
resize_image
_tensor
],
[
info
for
info
in
KERNEL_INFOS
if
info
.
kernel
is
F
.
resize_image
],
args_kwargs_fn
=
lambda
info
:
info
.
reference_inputs_fn
(),
args_kwargs_fn
=
lambda
info
:
info
.
reference_inputs_fn
(),
)
)
def
test_memory_format_consistency_resize_image_tensor
(
test_id
,
info
,
args_kwargs
):
def
test_memory_format_consistency_resize_image_tensor
(
test_id
,
info
,
args_kwargs
):
...
...
test/test_transforms_v2_refactored.py
View file @
ca012d39
...
@@ -437,7 +437,7 @@ class TestResize:
...
@@ -437,7 +437,7 @@ class TestResize:
check_cuda_vs_cpu_tolerances
=
dict
(
rtol
=
0
,
atol
=
atol
/
255
if
dtype
.
is_floating_point
else
atol
)
check_cuda_vs_cpu_tolerances
=
dict
(
rtol
=
0
,
atol
=
atol
/
255
if
dtype
.
is_floating_point
else
atol
)
check_kernel
(
check_kernel
(
F
.
resize_image
_tensor
,
F
.
resize_image
,
make_image
(
self
.
INPUT_SIZE
,
dtype
=
dtype
,
device
=
device
),
make_image
(
self
.
INPUT_SIZE
,
dtype
=
dtype
,
device
=
device
),
size
=
size
,
size
=
size
,
interpolation
=
interpolation
,
interpolation
=
interpolation
,
...
@@ -495,9 +495,9 @@ class TestResize:
...
@@ -495,9 +495,9 @@ class TestResize:
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"kernel"
,
"input_type"
),
(
"kernel"
,
"input_type"
),
[
[
(
F
.
resize_image
_tensor
,
torch
.
Tensor
),
(
F
.
resize_image
,
torch
.
Tensor
),
(
F
.
resize_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
_
resize_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
resize_image
_tensor
,
datapoints
.
Image
),
(
F
.
resize_image
,
datapoints
.
Image
),
(
F
.
resize_bounding_boxes
,
datapoints
.
BoundingBoxes
),
(
F
.
resize_bounding_boxes
,
datapoints
.
BoundingBoxes
),
(
F
.
resize_mask
,
datapoints
.
Mask
),
(
F
.
resize_mask
,
datapoints
.
Mask
),
(
F
.
resize_video
,
datapoints
.
Video
),
(
F
.
resize_video
,
datapoints
.
Video
),
...
@@ -541,9 +541,7 @@ class TestResize:
...
@@ -541,9 +541,7 @@ class TestResize:
image
=
make_image
(
self
.
INPUT_SIZE
,
dtype
=
torch
.
uint8
)
image
=
make_image
(
self
.
INPUT_SIZE
,
dtype
=
torch
.
uint8
)
actual
=
fn
(
image
,
size
=
size
,
interpolation
=
interpolation
,
**
max_size_kwarg
,
antialias
=
True
)
actual
=
fn
(
image
,
size
=
size
,
interpolation
=
interpolation
,
**
max_size_kwarg
,
antialias
=
True
)
expected
=
F
.
to_image_tensor
(
expected
=
F
.
to_image
(
F
.
resize
(
F
.
to_pil_image
(
image
),
size
=
size
,
interpolation
=
interpolation
,
**
max_size_kwarg
))
F
.
resize
(
F
.
to_image_pil
(
image
),
size
=
size
,
interpolation
=
interpolation
,
**
max_size_kwarg
)
)
self
.
_check_output_size
(
image
,
actual
,
size
=
size
,
**
max_size_kwarg
)
self
.
_check_output_size
(
image
,
actual
,
size
=
size
,
**
max_size_kwarg
)
torch
.
testing
.
assert_close
(
actual
,
expected
,
atol
=
1
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
actual
,
expected
,
atol
=
1
,
rtol
=
0
)
...
@@ -739,7 +737,7 @@ class TestHorizontalFlip:
...
@@ -739,7 +737,7 @@ class TestHorizontalFlip:
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
uint8
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
uint8
])
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
def
test_kernel_image_tensor
(
self
,
dtype
,
device
):
def
test_kernel_image_tensor
(
self
,
dtype
,
device
):
check_kernel
(
F
.
horizontal_flip_image
_tensor
,
make_image
(
dtype
=
dtype
,
device
=
device
))
check_kernel
(
F
.
horizontal_flip_image
,
make_image
(
dtype
=
dtype
,
device
=
device
))
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
datapoints
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
datapoints
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
int64
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
int64
])
...
@@ -770,9 +768,9 @@ class TestHorizontalFlip:
...
@@ -770,9 +768,9 @@ class TestHorizontalFlip:
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"kernel"
,
"input_type"
),
(
"kernel"
,
"input_type"
),
[
[
(
F
.
horizontal_flip_image
_tensor
,
torch
.
Tensor
),
(
F
.
horizontal_flip_image
,
torch
.
Tensor
),
(
F
.
horizontal_flip_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
_
horizontal_flip_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
horizontal_flip_image
_tensor
,
datapoints
.
Image
),
(
F
.
horizontal_flip_image
,
datapoints
.
Image
),
(
F
.
horizontal_flip_bounding_boxes
,
datapoints
.
BoundingBoxes
),
(
F
.
horizontal_flip_bounding_boxes
,
datapoints
.
BoundingBoxes
),
(
F
.
horizontal_flip_mask
,
datapoints
.
Mask
),
(
F
.
horizontal_flip_mask
,
datapoints
.
Mask
),
(
F
.
horizontal_flip_video
,
datapoints
.
Video
),
(
F
.
horizontal_flip_video
,
datapoints
.
Video
),
...
@@ -796,7 +794,7 @@ class TestHorizontalFlip:
...
@@ -796,7 +794,7 @@ class TestHorizontalFlip:
image
=
make_image
(
dtype
=
torch
.
uint8
,
device
=
"cpu"
)
image
=
make_image
(
dtype
=
torch
.
uint8
,
device
=
"cpu"
)
actual
=
fn
(
image
)
actual
=
fn
(
image
)
expected
=
F
.
to_image
_tensor
(
F
.
horizontal_flip
(
F
.
to_image
_pil
(
image
)))
expected
=
F
.
to_image
(
F
.
horizontal_flip
(
F
.
to_
pil_
image
(
image
)))
torch
.
testing
.
assert_close
(
actual
,
expected
)
torch
.
testing
.
assert_close
(
actual
,
expected
)
...
@@ -900,7 +898,7 @@ class TestAffine:
...
@@ -900,7 +898,7 @@ class TestAffine:
if
param
==
"fill"
:
if
param
==
"fill"
:
value
=
adapt_fill
(
value
,
dtype
=
dtype
)
value
=
adapt_fill
(
value
,
dtype
=
dtype
)
self
.
_check_kernel
(
self
.
_check_kernel
(
F
.
affine_image
_tensor
,
F
.
affine_image
,
make_image
(
dtype
=
dtype
,
device
=
device
),
make_image
(
dtype
=
dtype
,
device
=
device
),
**
{
param
:
value
},
**
{
param
:
value
},
check_scripted_vs_eager
=
not
(
param
in
{
"shear"
,
"fill"
}
and
isinstance
(
value
,
(
int
,
float
))),
check_scripted_vs_eager
=
not
(
param
in
{
"shear"
,
"fill"
}
and
isinstance
(
value
,
(
int
,
float
))),
...
@@ -946,9 +944,9 @@ class TestAffine:
...
@@ -946,9 +944,9 @@ class TestAffine:
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"kernel"
,
"input_type"
),
(
"kernel"
,
"input_type"
),
[
[
(
F
.
affine_image
_tensor
,
torch
.
Tensor
),
(
F
.
affine_image
,
torch
.
Tensor
),
(
F
.
affine_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
_
affine_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
affine_image
_tensor
,
datapoints
.
Image
),
(
F
.
affine_image
,
datapoints
.
Image
),
(
F
.
affine_bounding_boxes
,
datapoints
.
BoundingBoxes
),
(
F
.
affine_bounding_boxes
,
datapoints
.
BoundingBoxes
),
(
F
.
affine_mask
,
datapoints
.
Mask
),
(
F
.
affine_mask
,
datapoints
.
Mask
),
(
F
.
affine_video
,
datapoints
.
Video
),
(
F
.
affine_video
,
datapoints
.
Video
),
...
@@ -991,9 +989,9 @@ class TestAffine:
...
@@ -991,9 +989,9 @@ class TestAffine:
interpolation
=
interpolation
,
interpolation
=
interpolation
,
fill
=
fill
,
fill
=
fill
,
)
)
expected
=
F
.
to_image
_tensor
(
expected
=
F
.
to_image
(
F
.
affine
(
F
.
affine
(
F
.
to_image
_pil
(
image
),
F
.
to_
pil_
image
(
image
),
angle
=
angle
,
angle
=
angle
,
translate
=
translate
,
translate
=
translate
,
scale
=
scale
,
scale
=
scale
,
...
@@ -1026,7 +1024,7 @@ class TestAffine:
...
@@ -1026,7 +1024,7 @@ class TestAffine:
actual
=
transform
(
image
)
actual
=
transform
(
image
)
torch
.
manual_seed
(
seed
)
torch
.
manual_seed
(
seed
)
expected
=
F
.
to_image
_tensor
(
transform
(
F
.
to_image
_pil
(
image
)))
expected
=
F
.
to_image
(
transform
(
F
.
to_
pil_
image
(
image
)))
mae
=
(
actual
.
float
()
-
expected
.
float
()).
abs
().
mean
()
mae
=
(
actual
.
float
()
-
expected
.
float
()).
abs
().
mean
()
assert
mae
<
2
if
interpolation
is
transforms
.
InterpolationMode
.
NEAREST
else
8
assert
mae
<
2
if
interpolation
is
transforms
.
InterpolationMode
.
NEAREST
else
8
...
@@ -1204,7 +1202,7 @@ class TestVerticalFlip:
...
@@ -1204,7 +1202,7 @@ class TestVerticalFlip:
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
uint8
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
uint8
])
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
def
test_kernel_image_tensor
(
self
,
dtype
,
device
):
def
test_kernel_image_tensor
(
self
,
dtype
,
device
):
check_kernel
(
F
.
vertical_flip_image
_tensor
,
make_image
(
dtype
=
dtype
,
device
=
device
))
check_kernel
(
F
.
vertical_flip_image
,
make_image
(
dtype
=
dtype
,
device
=
device
))
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
datapoints
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
datapoints
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
int64
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
int64
])
...
@@ -1235,9 +1233,9 @@ class TestVerticalFlip:
...
@@ -1235,9 +1233,9 @@ class TestVerticalFlip:
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"kernel"
,
"input_type"
),
(
"kernel"
,
"input_type"
),
[
[
(
F
.
vertical_flip_image
_tensor
,
torch
.
Tensor
),
(
F
.
vertical_flip_image
,
torch
.
Tensor
),
(
F
.
vertical_flip_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
_
vertical_flip_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
vertical_flip_image
_tensor
,
datapoints
.
Image
),
(
F
.
vertical_flip_image
,
datapoints
.
Image
),
(
F
.
vertical_flip_bounding_boxes
,
datapoints
.
BoundingBoxes
),
(
F
.
vertical_flip_bounding_boxes
,
datapoints
.
BoundingBoxes
),
(
F
.
vertical_flip_mask
,
datapoints
.
Mask
),
(
F
.
vertical_flip_mask
,
datapoints
.
Mask
),
(
F
.
vertical_flip_video
,
datapoints
.
Video
),
(
F
.
vertical_flip_video
,
datapoints
.
Video
),
...
@@ -1259,7 +1257,7 @@ class TestVerticalFlip:
...
@@ -1259,7 +1257,7 @@ class TestVerticalFlip:
image
=
make_image
(
dtype
=
torch
.
uint8
,
device
=
"cpu"
)
image
=
make_image
(
dtype
=
torch
.
uint8
,
device
=
"cpu"
)
actual
=
fn
(
image
)
actual
=
fn
(
image
)
expected
=
F
.
to_image
_tensor
(
F
.
vertical_flip
(
F
.
to_image
_pil
(
image
)))
expected
=
F
.
to_image
(
F
.
vertical_flip
(
F
.
to_
pil_
image
(
image
)))
torch
.
testing
.
assert_close
(
actual
,
expected
)
torch
.
testing
.
assert_close
(
actual
,
expected
)
...
@@ -1339,7 +1337,7 @@ class TestRotate:
...
@@ -1339,7 +1337,7 @@ class TestRotate:
if
param
!=
"angle"
:
if
param
!=
"angle"
:
kwargs
[
"angle"
]
=
self
.
_MINIMAL_AFFINE_KWARGS
[
"angle"
]
kwargs
[
"angle"
]
=
self
.
_MINIMAL_AFFINE_KWARGS
[
"angle"
]
check_kernel
(
check_kernel
(
F
.
rotate_image
_tensor
,
F
.
rotate_image
,
make_image
(
dtype
=
dtype
,
device
=
device
),
make_image
(
dtype
=
dtype
,
device
=
device
),
**
kwargs
,
**
kwargs
,
check_scripted_vs_eager
=
not
(
param
==
"fill"
and
isinstance
(
value
,
(
int
,
float
))),
check_scripted_vs_eager
=
not
(
param
==
"fill"
and
isinstance
(
value
,
(
int
,
float
))),
...
@@ -1385,9 +1383,9 @@ class TestRotate:
...
@@ -1385,9 +1383,9 @@ class TestRotate:
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"kernel"
,
"input_type"
),
(
"kernel"
,
"input_type"
),
[
[
(
F
.
rotate_image
_tensor
,
torch
.
Tensor
),
(
F
.
rotate_image
,
torch
.
Tensor
),
(
F
.
rotate_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
_
rotate_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
rotate_image
_tensor
,
datapoints
.
Image
),
(
F
.
rotate_image
,
datapoints
.
Image
),
(
F
.
rotate_bounding_boxes
,
datapoints
.
BoundingBoxes
),
(
F
.
rotate_bounding_boxes
,
datapoints
.
BoundingBoxes
),
(
F
.
rotate_mask
,
datapoints
.
Mask
),
(
F
.
rotate_mask
,
datapoints
.
Mask
),
(
F
.
rotate_video
,
datapoints
.
Video
),
(
F
.
rotate_video
,
datapoints
.
Video
),
...
@@ -1419,9 +1417,9 @@ class TestRotate:
...
@@ -1419,9 +1417,9 @@ class TestRotate:
fill
=
adapt_fill
(
fill
,
dtype
=
torch
.
uint8
)
fill
=
adapt_fill
(
fill
,
dtype
=
torch
.
uint8
)
actual
=
F
.
rotate
(
image
,
angle
=
angle
,
center
=
center
,
interpolation
=
interpolation
,
expand
=
expand
,
fill
=
fill
)
actual
=
F
.
rotate
(
image
,
angle
=
angle
,
center
=
center
,
interpolation
=
interpolation
,
expand
=
expand
,
fill
=
fill
)
expected
=
F
.
to_image
_tensor
(
expected
=
F
.
to_image
(
F
.
rotate
(
F
.
rotate
(
F
.
to_image
_pil
(
image
),
angle
=
angle
,
center
=
center
,
interpolation
=
interpolation
,
expand
=
expand
,
fill
=
fill
F
.
to_
pil_
image
(
image
),
angle
=
angle
,
center
=
center
,
interpolation
=
interpolation
,
expand
=
expand
,
fill
=
fill
)
)
)
)
...
@@ -1452,7 +1450,7 @@ class TestRotate:
...
@@ -1452,7 +1450,7 @@ class TestRotate:
actual
=
transform
(
image
)
actual
=
transform
(
image
)
torch
.
manual_seed
(
seed
)
torch
.
manual_seed
(
seed
)
expected
=
F
.
to_image
_tensor
(
transform
(
F
.
to_image
_pil
(
image
)))
expected
=
F
.
to_image
(
transform
(
F
.
to_
pil_
image
(
image
)))
mae
=
(
actual
.
float
()
-
expected
.
float
()).
abs
().
mean
()
mae
=
(
actual
.
float
()
-
expected
.
float
()).
abs
().
mean
()
assert
mae
<
1
if
interpolation
is
transforms
.
InterpolationMode
.
NEAREST
else
6
assert
mae
<
1
if
interpolation
is
transforms
.
InterpolationMode
.
NEAREST
else
6
...
@@ -1621,8 +1619,8 @@ class TestToDtype:
...
@@ -1621,8 +1619,8 @@ class TestToDtype:
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"kernel"
,
"make_input"
),
(
"kernel"
,
"make_input"
),
[
[
(
F
.
to_dtype_image
_tensor
,
make_image_tensor
),
(
F
.
to_dtype_image
,
make_image_tensor
),
(
F
.
to_dtype_image
_tensor
,
make_image
),
(
F
.
to_dtype_image
,
make_image
),
(
F
.
to_dtype_video
,
make_video
),
(
F
.
to_dtype_video
,
make_video
),
],
],
)
)
...
@@ -1801,7 +1799,7 @@ class TestAdjustBrightness:
...
@@ -1801,7 +1799,7 @@ class TestAdjustBrightness:
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"kernel"
,
"make_input"
),
(
"kernel"
,
"make_input"
),
[
[
(
F
.
adjust_brightness_image
_tensor
,
make_image
),
(
F
.
adjust_brightness_image
,
make_image
),
(
F
.
adjust_brightness_video
,
make_video
),
(
F
.
adjust_brightness_video
,
make_video
),
],
],
)
)
...
@@ -1817,9 +1815,9 @@ class TestAdjustBrightness:
...
@@ -1817,9 +1815,9 @@ class TestAdjustBrightness:
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"kernel"
,
"input_type"
),
(
"kernel"
,
"input_type"
),
[
[
(
F
.
adjust_brightness_image
_tensor
,
torch
.
Tensor
),
(
F
.
adjust_brightness_image
,
torch
.
Tensor
),
(
F
.
adjust_brightness_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
_
adjust_brightness_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
adjust_brightness_image
_tensor
,
datapoints
.
Image
),
(
F
.
adjust_brightness_image
,
datapoints
.
Image
),
(
F
.
adjust_brightness_video
,
datapoints
.
Video
),
(
F
.
adjust_brightness_video
,
datapoints
.
Video
),
],
],
)
)
...
@@ -1831,7 +1829,7 @@ class TestAdjustBrightness:
...
@@ -1831,7 +1829,7 @@ class TestAdjustBrightness:
image
=
make_image
(
dtype
=
torch
.
uint8
,
device
=
"cpu"
)
image
=
make_image
(
dtype
=
torch
.
uint8
,
device
=
"cpu"
)
actual
=
F
.
adjust_brightness
(
image
,
brightness_factor
=
brightness_factor
)
actual
=
F
.
adjust_brightness
(
image
,
brightness_factor
=
brightness_factor
)
expected
=
F
.
to_image
_tensor
(
F
.
adjust_brightness
(
F
.
to_image
_pil
(
image
),
brightness_factor
=
brightness_factor
))
expected
=
F
.
to_image
(
F
.
adjust_brightness
(
F
.
to_
pil_
image
(
image
),
brightness_factor
=
brightness_factor
))
torch
.
testing
.
assert_close
(
actual
,
expected
)
torch
.
testing
.
assert_close
(
actual
,
expected
)
...
@@ -1979,9 +1977,9 @@ class TestShapeGetters:
...
@@ -1979,9 +1977,9 @@ class TestShapeGetters:
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"kernel"
,
"make_input"
),
(
"kernel"
,
"make_input"
),
[
[
(
F
.
get_dimensions_image
_tensor
,
make_image_tensor
),
(
F
.
get_dimensions_image
,
make_image_tensor
),
(
F
.
get_dimensions_image_pil
,
make_image_pil
),
(
F
.
_
get_dimensions_image_pil
,
make_image_pil
),
(
F
.
get_dimensions_image
_tensor
,
make_image
),
(
F
.
get_dimensions_image
,
make_image
),
(
F
.
get_dimensions_video
,
make_video
),
(
F
.
get_dimensions_video
,
make_video
),
],
],
)
)
...
@@ -1996,9 +1994,9 @@ class TestShapeGetters:
...
@@ -1996,9 +1994,9 @@ class TestShapeGetters:
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"kernel"
,
"make_input"
),
(
"kernel"
,
"make_input"
),
[
[
(
F
.
get_num_channels_image
_tensor
,
make_image_tensor
),
(
F
.
get_num_channels_image
,
make_image_tensor
),
(
F
.
get_num_channels_image_pil
,
make_image_pil
),
(
F
.
_
get_num_channels_image_pil
,
make_image_pil
),
(
F
.
get_num_channels_image
_tensor
,
make_image
),
(
F
.
get_num_channels_image
,
make_image
),
(
F
.
get_num_channels_video
,
make_video
),
(
F
.
get_num_channels_video
,
make_video
),
],
],
)
)
...
@@ -2012,9 +2010,9 @@ class TestShapeGetters:
...
@@ -2012,9 +2010,9 @@ class TestShapeGetters:
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"kernel"
,
"make_input"
),
(
"kernel"
,
"make_input"
),
[
[
(
F
.
get_size_image
_tensor
,
make_image_tensor
),
(
F
.
get_size_image
,
make_image_tensor
),
(
F
.
get_size_image_pil
,
make_image_pil
),
(
F
.
_
get_size_image_pil
,
make_image_pil
),
(
F
.
get_size_image
_tensor
,
make_image
),
(
F
.
get_size_image
,
make_image
),
(
F
.
get_size_bounding_boxes
,
make_bounding_box
),
(
F
.
get_size_bounding_boxes
,
make_bounding_box
),
(
F
.
get_size_mask
,
make_detection_mask
),
(
F
.
get_size_mask
,
make_detection_mask
),
(
F
.
get_size_mask
,
make_segmentation_mask
),
(
F
.
get_size_mask
,
make_segmentation_mask
),
...
@@ -2101,7 +2099,7 @@ class TestRegisterKernel:
...
@@ -2101,7 +2099,7 @@ class TestRegisterKernel:
F
.
register_kernel
(
F
.
resize
,
object
)
F
.
register_kernel
(
F
.
resize
,
object
)
with
pytest
.
raises
(
ValueError
,
match
=
"cannot be registered for the builtin datapoint classes"
):
with
pytest
.
raises
(
ValueError
,
match
=
"cannot be registered for the builtin datapoint classes"
):
F
.
register_kernel
(
F
.
resize
,
datapoints
.
Image
)(
F
.
resize_image
_tensor
)
F
.
register_kernel
(
F
.
resize
,
datapoints
.
Image
)(
F
.
resize_image
)
class
CustomDatapoint
(
datapoints
.
Datapoint
):
class
CustomDatapoint
(
datapoints
.
Datapoint
):
pass
pass
...
@@ -2119,9 +2117,9 @@ class TestGetKernel:
...
@@ -2119,9 +2117,9 @@ class TestGetKernel:
# We are using F.resize as functional and the kernels below as proxy. Any other functional / kernels combination
# We are using F.resize as functional and the kernels below as proxy. Any other functional / kernels combination
# would also be fine
# would also be fine
KERNELS
=
{
KERNELS
=
{
torch
.
Tensor
:
F
.
resize_image
_tensor
,
torch
.
Tensor
:
F
.
resize_image
,
PIL
.
Image
.
Image
:
F
.
resize_image_pil
,
PIL
.
Image
.
Image
:
F
.
_
resize_image_pil
,
datapoints
.
Image
:
F
.
resize_image
_tensor
,
datapoints
.
Image
:
F
.
resize_image
,
datapoints
.
BoundingBoxes
:
F
.
resize_bounding_boxes
,
datapoints
.
BoundingBoxes
:
F
.
resize_bounding_boxes
,
datapoints
.
Mask
:
F
.
resize_mask
,
datapoints
.
Mask
:
F
.
resize_mask
,
datapoints
.
Video
:
F
.
resize_video
,
datapoints
.
Video
:
F
.
resize_video
,
...
@@ -2217,10 +2215,10 @@ class TestPermuteChannels:
...
@@ -2217,10 +2215,10 @@ class TestPermuteChannels:
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"kernel"
,
"make_input"
),
(
"kernel"
,
"make_input"
),
[
[
(
F
.
permute_channels_image
_tensor
,
make_image_tensor
),
(
F
.
permute_channels_image
,
make_image_tensor
),
# FIXME
# FIXME
# check_kernel does not support PIL kernel, but it should
# check_kernel does not support PIL kernel, but it should
(
F
.
permute_channels_image
_tensor
,
make_image
),
(
F
.
permute_channels_image
,
make_image
),
(
F
.
permute_channels_video
,
make_video
),
(
F
.
permute_channels_video
,
make_video
),
],
],
)
)
...
@@ -2236,9 +2234,9 @@ class TestPermuteChannels:
...
@@ -2236,9 +2234,9 @@ class TestPermuteChannels:
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"kernel"
,
"input_type"
),
(
"kernel"
,
"input_type"
),
[
[
(
F
.
permute_channels_image
_tensor
,
torch
.
Tensor
),
(
F
.
permute_channels_image
,
torch
.
Tensor
),
(
F
.
permute_channels_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
_
permute_channels_image_pil
,
PIL
.
Image
.
Image
),
(
F
.
permute_channels_image
_tensor
,
datapoints
.
Image
),
(
F
.
permute_channels_image
,
datapoints
.
Image
),
(
F
.
permute_channels_video
,
datapoints
.
Video
),
(
F
.
permute_channels_video
,
datapoints
.
Video
),
],
],
)
)
...
...
test/test_transforms_v2_utils.py
View file @
ca012d39
...
@@ -7,7 +7,7 @@ import torchvision.transforms.v2.utils
...
@@ -7,7 +7,7 @@ import torchvision.transforms.v2.utils
from
common_utils
import
DEFAULT_SIZE
,
make_bounding_box
,
make_detection_mask
,
make_image
from
common_utils
import
DEFAULT_SIZE
,
make_bounding_box
,
make_detection_mask
,
make_image
from
torchvision
import
datapoints
from
torchvision
import
datapoints
from
torchvision.transforms.v2.functional
import
to_image
_pil
from
torchvision.transforms.v2.functional
import
to_
pil_
image
from
torchvision.transforms.v2.utils
import
has_all
,
has_any
from
torchvision.transforms.v2.utils
import
has_all
,
has_any
...
@@ -44,7 +44,7 @@ MASK = make_detection_mask(DEFAULT_SIZE)
...
@@ -44,7 +44,7 @@ MASK = make_detection_mask(DEFAULT_SIZE)
True
,
True
,
),
),
(
(
(
to_image
_pil
(
IMAGE
),),
(
to_
pil_
image
(
IMAGE
),),
(
datapoints
.
Image
,
PIL
.
Image
.
Image
,
torchvision
.
transforms
.
v2
.
utils
.
is_simple_tensor
),
(
datapoints
.
Image
,
PIL
.
Image
.
Image
,
torchvision
.
transforms
.
v2
.
utils
.
is_simple_tensor
),
True
,
True
,
),
),
...
...
test/transforms_v2_dispatcher_infos.py
View file @
ca012d39
...
@@ -142,32 +142,32 @@ DISPATCHER_INFOS = [
...
@@ -142,32 +142,32 @@ DISPATCHER_INFOS = [
DispatcherInfo
(
DispatcherInfo
(
F
.
crop
,
F
.
crop
,
kernels
=
{
kernels
=
{
datapoints
.
Image
:
F
.
crop_image
_tensor
,
datapoints
.
Image
:
F
.
crop_image
,
datapoints
.
Video
:
F
.
crop_video
,
datapoints
.
Video
:
F
.
crop_video
,
datapoints
.
BoundingBoxes
:
F
.
crop_bounding_boxes
,
datapoints
.
BoundingBoxes
:
F
.
crop_bounding_boxes
,
datapoints
.
Mask
:
F
.
crop_mask
,
datapoints
.
Mask
:
F
.
crop_mask
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
crop_image_pil
,
kernel_name
=
"crop_image_pil"
),
pil_kernel_info
=
PILKernelInfo
(
F
.
_
crop_image_pil
,
kernel_name
=
"crop_image_pil"
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
resized_crop
,
F
.
resized_crop
,
kernels
=
{
kernels
=
{
datapoints
.
Image
:
F
.
resized_crop_image
_tensor
,
datapoints
.
Image
:
F
.
resized_crop_image
,
datapoints
.
Video
:
F
.
resized_crop_video
,
datapoints
.
Video
:
F
.
resized_crop_video
,
datapoints
.
BoundingBoxes
:
F
.
resized_crop_bounding_boxes
,
datapoints
.
BoundingBoxes
:
F
.
resized_crop_bounding_boxes
,
datapoints
.
Mask
:
F
.
resized_crop_mask
,
datapoints
.
Mask
:
F
.
resized_crop_mask
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
resized_crop_image_pil
),
pil_kernel_info
=
PILKernelInfo
(
F
.
_
resized_crop_image_pil
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
pad
,
F
.
pad
,
kernels
=
{
kernels
=
{
datapoints
.
Image
:
F
.
pad_image
_tensor
,
datapoints
.
Image
:
F
.
pad_image
,
datapoints
.
Video
:
F
.
pad_video
,
datapoints
.
Video
:
F
.
pad_video
,
datapoints
.
BoundingBoxes
:
F
.
pad_bounding_boxes
,
datapoints
.
BoundingBoxes
:
F
.
pad_bounding_boxes
,
datapoints
.
Mask
:
F
.
pad_mask
,
datapoints
.
Mask
:
F
.
pad_mask
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
pad_image_pil
,
kernel_name
=
"pad_image_pil"
),
pil_kernel_info
=
PILKernelInfo
(
F
.
_
pad_image_pil
,
kernel_name
=
"pad_image_pil"
),
test_marks
=
[
test_marks
=
[
*
xfails_pil
(
*
xfails_pil
(
reason
=
(
reason
=
(
...
@@ -184,12 +184,12 @@ DISPATCHER_INFOS = [
...
@@ -184,12 +184,12 @@ DISPATCHER_INFOS = [
DispatcherInfo
(
DispatcherInfo
(
F
.
perspective
,
F
.
perspective
,
kernels
=
{
kernels
=
{
datapoints
.
Image
:
F
.
perspective_image
_tensor
,
datapoints
.
Image
:
F
.
perspective_image
,
datapoints
.
Video
:
F
.
perspective_video
,
datapoints
.
Video
:
F
.
perspective_video
,
datapoints
.
BoundingBoxes
:
F
.
perspective_bounding_boxes
,
datapoints
.
BoundingBoxes
:
F
.
perspective_bounding_boxes
,
datapoints
.
Mask
:
F
.
perspective_mask
,
datapoints
.
Mask
:
F
.
perspective_mask
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
perspective_image_pil
),
pil_kernel_info
=
PILKernelInfo
(
F
.
_
perspective_image_pil
),
test_marks
=
[
test_marks
=
[
*
xfails_pil_if_fill_sequence_needs_broadcast
,
*
xfails_pil_if_fill_sequence_needs_broadcast
,
xfail_jit_python_scalar_arg
(
"fill"
),
xfail_jit_python_scalar_arg
(
"fill"
),
...
@@ -198,23 +198,23 @@ DISPATCHER_INFOS = [
...
@@ -198,23 +198,23 @@ DISPATCHER_INFOS = [
DispatcherInfo
(
DispatcherInfo
(
F
.
elastic
,
F
.
elastic
,
kernels
=
{
kernels
=
{
datapoints
.
Image
:
F
.
elastic_image
_tensor
,
datapoints
.
Image
:
F
.
elastic_image
,
datapoints
.
Video
:
F
.
elastic_video
,
datapoints
.
Video
:
F
.
elastic_video
,
datapoints
.
BoundingBoxes
:
F
.
elastic_bounding_boxes
,
datapoints
.
BoundingBoxes
:
F
.
elastic_bounding_boxes
,
datapoints
.
Mask
:
F
.
elastic_mask
,
datapoints
.
Mask
:
F
.
elastic_mask
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
elastic_image_pil
),
pil_kernel_info
=
PILKernelInfo
(
F
.
_
elastic_image_pil
),
test_marks
=
[
xfail_jit_python_scalar_arg
(
"fill"
)],
test_marks
=
[
xfail_jit_python_scalar_arg
(
"fill"
)],
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
center_crop
,
F
.
center_crop
,
kernels
=
{
kernels
=
{
datapoints
.
Image
:
F
.
center_crop_image
_tensor
,
datapoints
.
Image
:
F
.
center_crop_image
,
datapoints
.
Video
:
F
.
center_crop_video
,
datapoints
.
Video
:
F
.
center_crop_video
,
datapoints
.
BoundingBoxes
:
F
.
center_crop_bounding_boxes
,
datapoints
.
BoundingBoxes
:
F
.
center_crop_bounding_boxes
,
datapoints
.
Mask
:
F
.
center_crop_mask
,
datapoints
.
Mask
:
F
.
center_crop_mask
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
center_crop_image_pil
),
pil_kernel_info
=
PILKernelInfo
(
F
.
_
center_crop_image_pil
),
test_marks
=
[
test_marks
=
[
xfail_jit_python_scalar_arg
(
"output_size"
),
xfail_jit_python_scalar_arg
(
"output_size"
),
],
],
...
@@ -222,10 +222,10 @@ DISPATCHER_INFOS = [
...
@@ -222,10 +222,10 @@ DISPATCHER_INFOS = [
DispatcherInfo
(
DispatcherInfo
(
F
.
gaussian_blur
,
F
.
gaussian_blur
,
kernels
=
{
kernels
=
{
datapoints
.
Image
:
F
.
gaussian_blur_image
_tensor
,
datapoints
.
Image
:
F
.
gaussian_blur_image
,
datapoints
.
Video
:
F
.
gaussian_blur_video
,
datapoints
.
Video
:
F
.
gaussian_blur_video
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
gaussian_blur_image_pil
),
pil_kernel_info
=
PILKernelInfo
(
F
.
_
gaussian_blur_image_pil
),
test_marks
=
[
test_marks
=
[
xfail_jit_python_scalar_arg
(
"kernel_size"
),
xfail_jit_python_scalar_arg
(
"kernel_size"
),
xfail_jit_python_scalar_arg
(
"sigma"
),
xfail_jit_python_scalar_arg
(
"sigma"
),
...
@@ -234,58 +234,58 @@ DISPATCHER_INFOS = [
...
@@ -234,58 +234,58 @@ DISPATCHER_INFOS = [
DispatcherInfo
(
DispatcherInfo
(
F
.
equalize
,
F
.
equalize
,
kernels
=
{
kernels
=
{
datapoints
.
Image
:
F
.
equalize_image
_tensor
,
datapoints
.
Image
:
F
.
equalize_image
,
datapoints
.
Video
:
F
.
equalize_video
,
datapoints
.
Video
:
F
.
equalize_video
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
equalize_image_pil
,
kernel_name
=
"equalize_image_pil"
),
pil_kernel_info
=
PILKernelInfo
(
F
.
_
equalize_image_pil
,
kernel_name
=
"equalize_image_pil"
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
invert
,
F
.
invert
,
kernels
=
{
kernels
=
{
datapoints
.
Image
:
F
.
invert_image
_tensor
,
datapoints
.
Image
:
F
.
invert_image
,
datapoints
.
Video
:
F
.
invert_video
,
datapoints
.
Video
:
F
.
invert_video
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
invert_image_pil
,
kernel_name
=
"invert_image_pil"
),
pil_kernel_info
=
PILKernelInfo
(
F
.
_
invert_image_pil
,
kernel_name
=
"invert_image_pil"
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
posterize
,
F
.
posterize
,
kernels
=
{
kernels
=
{
datapoints
.
Image
:
F
.
posterize_image
_tensor
,
datapoints
.
Image
:
F
.
posterize_image
,
datapoints
.
Video
:
F
.
posterize_video
,
datapoints
.
Video
:
F
.
posterize_video
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
posterize_image_pil
,
kernel_name
=
"posterize_image_pil"
),
pil_kernel_info
=
PILKernelInfo
(
F
.
_
posterize_image_pil
,
kernel_name
=
"posterize_image_pil"
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
solarize
,
F
.
solarize
,
kernels
=
{
kernels
=
{
datapoints
.
Image
:
F
.
solarize_image
_tensor
,
datapoints
.
Image
:
F
.
solarize_image
,
datapoints
.
Video
:
F
.
solarize_video
,
datapoints
.
Video
:
F
.
solarize_video
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
solarize_image_pil
,
kernel_name
=
"solarize_image_pil"
),
pil_kernel_info
=
PILKernelInfo
(
F
.
_
solarize_image_pil
,
kernel_name
=
"solarize_image_pil"
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
autocontrast
,
F
.
autocontrast
,
kernels
=
{
kernels
=
{
datapoints
.
Image
:
F
.
autocontrast_image
_tensor
,
datapoints
.
Image
:
F
.
autocontrast_image
,
datapoints
.
Video
:
F
.
autocontrast_video
,
datapoints
.
Video
:
F
.
autocontrast_video
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
autocontrast_image_pil
,
kernel_name
=
"autocontrast_image_pil"
),
pil_kernel_info
=
PILKernelInfo
(
F
.
_
autocontrast_image_pil
,
kernel_name
=
"autocontrast_image_pil"
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
adjust_sharpness
,
F
.
adjust_sharpness
,
kernels
=
{
kernels
=
{
datapoints
.
Image
:
F
.
adjust_sharpness_image
_tensor
,
datapoints
.
Image
:
F
.
adjust_sharpness_image
,
datapoints
.
Video
:
F
.
adjust_sharpness_video
,
datapoints
.
Video
:
F
.
adjust_sharpness_video
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
adjust_sharpness_image_pil
,
kernel_name
=
"adjust_sharpness_image_pil"
),
pil_kernel_info
=
PILKernelInfo
(
F
.
_
adjust_sharpness_image_pil
,
kernel_name
=
"adjust_sharpness_image_pil"
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
erase
,
F
.
erase
,
kernels
=
{
kernels
=
{
datapoints
.
Image
:
F
.
erase_image
_tensor
,
datapoints
.
Image
:
F
.
erase_image
,
datapoints
.
Video
:
F
.
erase_video
,
datapoints
.
Video
:
F
.
erase_video
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
erase_image_pil
),
pil_kernel_info
=
PILKernelInfo
(
F
.
_
erase_image_pil
),
test_marks
=
[
test_marks
=
[
skip_dispatch_datapoint
,
skip_dispatch_datapoint
,
],
],
...
@@ -293,42 +293,42 @@ DISPATCHER_INFOS = [
...
@@ -293,42 +293,42 @@ DISPATCHER_INFOS = [
DispatcherInfo
(
DispatcherInfo
(
F
.
adjust_contrast
,
F
.
adjust_contrast
,
kernels
=
{
kernels
=
{
datapoints
.
Image
:
F
.
adjust_contrast_image
_tensor
,
datapoints
.
Image
:
F
.
adjust_contrast_image
,
datapoints
.
Video
:
F
.
adjust_contrast_video
,
datapoints
.
Video
:
F
.
adjust_contrast_video
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
adjust_contrast_image_pil
,
kernel_name
=
"adjust_contrast_image_pil"
),
pil_kernel_info
=
PILKernelInfo
(
F
.
_
adjust_contrast_image_pil
,
kernel_name
=
"adjust_contrast_image_pil"
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
adjust_gamma
,
F
.
adjust_gamma
,
kernels
=
{
kernels
=
{
datapoints
.
Image
:
F
.
adjust_gamma_image
_tensor
,
datapoints
.
Image
:
F
.
adjust_gamma_image
,
datapoints
.
Video
:
F
.
adjust_gamma_video
,
datapoints
.
Video
:
F
.
adjust_gamma_video
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
adjust_gamma_image_pil
,
kernel_name
=
"adjust_gamma_image_pil"
),
pil_kernel_info
=
PILKernelInfo
(
F
.
_
adjust_gamma_image_pil
,
kernel_name
=
"adjust_gamma_image_pil"
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
adjust_hue
,
F
.
adjust_hue
,
kernels
=
{
kernels
=
{
datapoints
.
Image
:
F
.
adjust_hue_image
_tensor
,
datapoints
.
Image
:
F
.
adjust_hue_image
,
datapoints
.
Video
:
F
.
adjust_hue_video
,
datapoints
.
Video
:
F
.
adjust_hue_video
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
adjust_hue_image_pil
,
kernel_name
=
"adjust_hue_image_pil"
),
pil_kernel_info
=
PILKernelInfo
(
F
.
_
adjust_hue_image_pil
,
kernel_name
=
"adjust_hue_image_pil"
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
adjust_saturation
,
F
.
adjust_saturation
,
kernels
=
{
kernels
=
{
datapoints
.
Image
:
F
.
adjust_saturation_image
_tensor
,
datapoints
.
Image
:
F
.
adjust_saturation_image
,
datapoints
.
Video
:
F
.
adjust_saturation_video
,
datapoints
.
Video
:
F
.
adjust_saturation_video
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
adjust_saturation_image_pil
,
kernel_name
=
"adjust_saturation_image_pil"
),
pil_kernel_info
=
PILKernelInfo
(
F
.
_
adjust_saturation_image_pil
,
kernel_name
=
"adjust_saturation_image_pil"
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
five_crop
,
F
.
five_crop
,
kernels
=
{
kernels
=
{
datapoints
.
Image
:
F
.
five_crop_image
_tensor
,
datapoints
.
Image
:
F
.
five_crop_image
,
datapoints
.
Video
:
F
.
five_crop_video
,
datapoints
.
Video
:
F
.
five_crop_video
,
},
},
pil_kernel_info
=
PILKernelInfo
(
F
.
five_crop_image_pil
),
pil_kernel_info
=
PILKernelInfo
(
F
.
_
five_crop_image_pil
),
test_marks
=
[
test_marks
=
[
xfail_jit_python_scalar_arg
(
"size"
),
xfail_jit_python_scalar_arg
(
"size"
),
*
multi_crop_skips
,
*
multi_crop_skips
,
...
@@ -337,19 +337,19 @@ DISPATCHER_INFOS = [
...
@@ -337,19 +337,19 @@ DISPATCHER_INFOS = [
DispatcherInfo
(
DispatcherInfo
(
F
.
ten_crop
,
F
.
ten_crop
,
kernels
=
{
kernels
=
{
datapoints
.
Image
:
F
.
ten_crop_image
_tensor
,
datapoints
.
Image
:
F
.
ten_crop_image
,
datapoints
.
Video
:
F
.
ten_crop_video
,
datapoints
.
Video
:
F
.
ten_crop_video
,
},
},
test_marks
=
[
test_marks
=
[
xfail_jit_python_scalar_arg
(
"size"
),
xfail_jit_python_scalar_arg
(
"size"
),
*
multi_crop_skips
,
*
multi_crop_skips
,
],
],
pil_kernel_info
=
PILKernelInfo
(
F
.
ten_crop_image_pil
),
pil_kernel_info
=
PILKernelInfo
(
F
.
_
ten_crop_image_pil
),
),
),
DispatcherInfo
(
DispatcherInfo
(
F
.
normalize
,
F
.
normalize
,
kernels
=
{
kernels
=
{
datapoints
.
Image
:
F
.
normalize_image
_tensor
,
datapoints
.
Image
:
F
.
normalize_image
,
datapoints
.
Video
:
F
.
normalize_video
,
datapoints
.
Video
:
F
.
normalize_video
,
},
},
test_marks
=
[
test_marks
=
[
...
...
test/transforms_v2_kernel_infos.py
View file @
ca012d39
...
@@ -122,12 +122,12 @@ def pil_reference_wrapper(pil_kernel):
...
@@ -122,12 +122,12 @@ def pil_reference_wrapper(pil_kernel):
f
"Can only test single tensor images against PIL, but input has shape
{
input_tensor
.
shape
}
"
f
"Can only test single tensor images against PIL, but input has shape
{
input_tensor
.
shape
}
"
)
)
input_pil
=
F
.
to_image
_pil
(
input_tensor
)
input_pil
=
F
.
to_
pil_
image
(
input_tensor
)
output_pil
=
pil_kernel
(
input_pil
,
*
other_args
,
**
kwargs
)
output_pil
=
pil_kernel
(
input_pil
,
*
other_args
,
**
kwargs
)
if
not
isinstance
(
output_pil
,
PIL
.
Image
.
Image
):
if
not
isinstance
(
output_pil
,
PIL
.
Image
.
Image
):
return
output_pil
return
output_pil
output_tensor
=
F
.
to_image
_tensor
(
output_pil
)
output_tensor
=
F
.
to_image
(
output_pil
)
# 2D mask shenanigans
# 2D mask shenanigans
if
output_tensor
.
ndim
==
2
and
input_tensor
.
ndim
==
3
:
if
output_tensor
.
ndim
==
2
and
input_tensor
.
ndim
==
3
:
...
@@ -331,10 +331,10 @@ def reference_inputs_crop_bounding_boxes():
...
@@ -331,10 +331,10 @@ def reference_inputs_crop_bounding_boxes():
KERNEL_INFOS
.
extend
(
KERNEL_INFOS
.
extend
(
[
[
KernelInfo
(
KernelInfo
(
F
.
crop_image
_tensor
,
F
.
crop_image
,
kernel_name
=
"crop_image_tensor"
,
kernel_name
=
"crop_image_tensor"
,
sample_inputs_fn
=
sample_inputs_crop_image_tensor
,
sample_inputs_fn
=
sample_inputs_crop_image_tensor
,
reference_fn
=
pil_reference_wrapper
(
F
.
crop_image_pil
),
reference_fn
=
pil_reference_wrapper
(
F
.
_
crop_image_pil
),
reference_inputs_fn
=
reference_inputs_crop_image_tensor
,
reference_inputs_fn
=
reference_inputs_crop_image_tensor
,
float32_vs_uint8
=
True
,
float32_vs_uint8
=
True
,
),
),
...
@@ -347,7 +347,7 @@ KERNEL_INFOS.extend(
...
@@ -347,7 +347,7 @@ KERNEL_INFOS.extend(
KernelInfo
(
KernelInfo
(
F
.
crop_mask
,
F
.
crop_mask
,
sample_inputs_fn
=
sample_inputs_crop_mask
,
sample_inputs_fn
=
sample_inputs_crop_mask
,
reference_fn
=
pil_reference_wrapper
(
F
.
crop_image_pil
),
reference_fn
=
pil_reference_wrapper
(
F
.
_
crop_image_pil
),
reference_inputs_fn
=
reference_inputs_crop_mask
,
reference_inputs_fn
=
reference_inputs_crop_mask
,
float32_vs_uint8
=
True
,
float32_vs_uint8
=
True
,
),
),
...
@@ -373,7 +373,7 @@ def reference_resized_crop_image_tensor(*args, **kwargs):
...
@@ -373,7 +373,7 @@ def reference_resized_crop_image_tensor(*args, **kwargs):
F
.
InterpolationMode
.
BICUBIC
,
F
.
InterpolationMode
.
BICUBIC
,
}:
}:
raise
pytest
.
UsageError
(
"Anti-aliasing is always active in PIL"
)
raise
pytest
.
UsageError
(
"Anti-aliasing is always active in PIL"
)
return
F
.
resized_crop_image_pil
(
*
args
,
**
kwargs
)
return
F
.
_
resized_crop_image_pil
(
*
args
,
**
kwargs
)
def
reference_inputs_resized_crop_image_tensor
():
def
reference_inputs_resized_crop_image_tensor
():
...
@@ -417,7 +417,7 @@ def sample_inputs_resized_crop_video():
...
@@ -417,7 +417,7 @@ def sample_inputs_resized_crop_video():
KERNEL_INFOS
.
extend
(
KERNEL_INFOS
.
extend
(
[
[
KernelInfo
(
KernelInfo
(
F
.
resized_crop_image
_tensor
,
F
.
resized_crop_image
,
sample_inputs_fn
=
sample_inputs_resized_crop_image_tensor
,
sample_inputs_fn
=
sample_inputs_resized_crop_image_tensor
,
reference_fn
=
reference_resized_crop_image_tensor
,
reference_fn
=
reference_resized_crop_image_tensor
,
reference_inputs_fn
=
reference_inputs_resized_crop_image_tensor
,
reference_inputs_fn
=
reference_inputs_resized_crop_image_tensor
,
...
@@ -570,9 +570,9 @@ def pad_xfail_jit_fill_condition(args_kwargs):
...
@@ -570,9 +570,9 @@ def pad_xfail_jit_fill_condition(args_kwargs):
KERNEL_INFOS
.
extend
(
KERNEL_INFOS
.
extend
(
[
[
KernelInfo
(
KernelInfo
(
F
.
pad_image
_tensor
,
F
.
pad_image
,
sample_inputs_fn
=
sample_inputs_pad_image_tensor
,
sample_inputs_fn
=
sample_inputs_pad_image_tensor
,
reference_fn
=
pil_reference_wrapper
(
F
.
pad_image_pil
),
reference_fn
=
pil_reference_wrapper
(
F
.
_
pad_image_pil
),
reference_inputs_fn
=
reference_inputs_pad_image_tensor
,
reference_inputs_fn
=
reference_inputs_pad_image_tensor
,
float32_vs_uint8
=
float32_vs_uint8_fill_adapter
,
float32_vs_uint8
=
float32_vs_uint8_fill_adapter
,
closeness_kwargs
=
float32_vs_uint8_pixel_difference
(),
closeness_kwargs
=
float32_vs_uint8_pixel_difference
(),
...
@@ -595,7 +595,7 @@ KERNEL_INFOS.extend(
...
@@ -595,7 +595,7 @@ KERNEL_INFOS.extend(
KernelInfo
(
KernelInfo
(
F
.
pad_mask
,
F
.
pad_mask
,
sample_inputs_fn
=
sample_inputs_pad_mask
,
sample_inputs_fn
=
sample_inputs_pad_mask
,
reference_fn
=
pil_reference_wrapper
(
F
.
pad_image_pil
),
reference_fn
=
pil_reference_wrapper
(
F
.
_
pad_image_pil
),
reference_inputs_fn
=
reference_inputs_pad_mask
,
reference_inputs_fn
=
reference_inputs_pad_mask
,
float32_vs_uint8
=
float32_vs_uint8_fill_adapter
,
float32_vs_uint8
=
float32_vs_uint8_fill_adapter
,
),
),
...
@@ -690,9 +690,9 @@ def sample_inputs_perspective_video():
...
@@ -690,9 +690,9 @@ def sample_inputs_perspective_video():
KERNEL_INFOS
.
extend
(
KERNEL_INFOS
.
extend
(
[
[
KernelInfo
(
KernelInfo
(
F
.
perspective_image
_tensor
,
F
.
perspective_image
,
sample_inputs_fn
=
sample_inputs_perspective_image_tensor
,
sample_inputs_fn
=
sample_inputs_perspective_image_tensor
,
reference_fn
=
pil_reference_wrapper
(
F
.
perspective_image_pil
),
reference_fn
=
pil_reference_wrapper
(
F
.
_
perspective_image_pil
),
reference_inputs_fn
=
reference_inputs_perspective_image_tensor
,
reference_inputs_fn
=
reference_inputs_perspective_image_tensor
,
float32_vs_uint8
=
float32_vs_uint8_fill_adapter
,
float32_vs_uint8
=
float32_vs_uint8_fill_adapter
,
closeness_kwargs
=
{
closeness_kwargs
=
{
...
@@ -715,7 +715,7 @@ KERNEL_INFOS.extend(
...
@@ -715,7 +715,7 @@ KERNEL_INFOS.extend(
KernelInfo
(
KernelInfo
(
F
.
perspective_mask
,
F
.
perspective_mask
,
sample_inputs_fn
=
sample_inputs_perspective_mask
,
sample_inputs_fn
=
sample_inputs_perspective_mask
,
reference_fn
=
pil_reference_wrapper
(
F
.
perspective_image_pil
),
reference_fn
=
pil_reference_wrapper
(
F
.
_
perspective_image_pil
),
reference_inputs_fn
=
reference_inputs_perspective_mask
,
reference_inputs_fn
=
reference_inputs_perspective_mask
,
float32_vs_uint8
=
True
,
float32_vs_uint8
=
True
,
closeness_kwargs
=
{
closeness_kwargs
=
{
...
@@ -786,7 +786,7 @@ def sample_inputs_elastic_video():
...
@@ -786,7 +786,7 @@ def sample_inputs_elastic_video():
KERNEL_INFOS
.
extend
(
KERNEL_INFOS
.
extend
(
[
[
KernelInfo
(
KernelInfo
(
F
.
elastic_image
_tensor
,
F
.
elastic_image
,
sample_inputs_fn
=
sample_inputs_elastic_image_tensor
,
sample_inputs_fn
=
sample_inputs_elastic_image_tensor
,
reference_inputs_fn
=
reference_inputs_elastic_image_tensor
,
reference_inputs_fn
=
reference_inputs_elastic_image_tensor
,
float32_vs_uint8
=
float32_vs_uint8_fill_adapter
,
float32_vs_uint8
=
float32_vs_uint8_fill_adapter
,
...
@@ -870,9 +870,9 @@ def sample_inputs_center_crop_video():
...
@@ -870,9 +870,9 @@ def sample_inputs_center_crop_video():
KERNEL_INFOS
.
extend
(
KERNEL_INFOS
.
extend
(
[
[
KernelInfo
(
KernelInfo
(
F
.
center_crop_image
_tensor
,
F
.
center_crop_image
,
sample_inputs_fn
=
sample_inputs_center_crop_image_tensor
,
sample_inputs_fn
=
sample_inputs_center_crop_image_tensor
,
reference_fn
=
pil_reference_wrapper
(
F
.
center_crop_image_pil
),
reference_fn
=
pil_reference_wrapper
(
F
.
_
center_crop_image_pil
),
reference_inputs_fn
=
reference_inputs_center_crop_image_tensor
,
reference_inputs_fn
=
reference_inputs_center_crop_image_tensor
,
float32_vs_uint8
=
True
,
float32_vs_uint8
=
True
,
test_marks
=
[
test_marks
=
[
...
@@ -889,7 +889,7 @@ KERNEL_INFOS.extend(
...
@@ -889,7 +889,7 @@ KERNEL_INFOS.extend(
KernelInfo
(
KernelInfo
(
F
.
center_crop_mask
,
F
.
center_crop_mask
,
sample_inputs_fn
=
sample_inputs_center_crop_mask
,
sample_inputs_fn
=
sample_inputs_center_crop_mask
,
reference_fn
=
pil_reference_wrapper
(
F
.
center_crop_image_pil
),
reference_fn
=
pil_reference_wrapper
(
F
.
_
center_crop_image_pil
),
reference_inputs_fn
=
reference_inputs_center_crop_mask
,
reference_inputs_fn
=
reference_inputs_center_crop_mask
,
float32_vs_uint8
=
True
,
float32_vs_uint8
=
True
,
test_marks
=
[
test_marks
=
[
...
@@ -924,7 +924,7 @@ def sample_inputs_gaussian_blur_video():
...
@@ -924,7 +924,7 @@ def sample_inputs_gaussian_blur_video():
KERNEL_INFOS
.
extend
(
KERNEL_INFOS
.
extend
(
[
[
KernelInfo
(
KernelInfo
(
F
.
gaussian_blur_image
_tensor
,
F
.
gaussian_blur_image
,
sample_inputs_fn
=
sample_inputs_gaussian_blur_image_tensor
,
sample_inputs_fn
=
sample_inputs_gaussian_blur_image_tensor
,
closeness_kwargs
=
cuda_vs_cpu_pixel_difference
(),
closeness_kwargs
=
cuda_vs_cpu_pixel_difference
(),
test_marks
=
[
test_marks
=
[
...
@@ -1010,10 +1010,10 @@ def sample_inputs_equalize_video():
...
@@ -1010,10 +1010,10 @@ def sample_inputs_equalize_video():
KERNEL_INFOS
.
extend
(
KERNEL_INFOS
.
extend
(
[
[
KernelInfo
(
KernelInfo
(
F
.
equalize_image
_tensor
,
F
.
equalize_image
,
kernel_name
=
"equalize_image_tensor"
,
kernel_name
=
"equalize_image_tensor"
,
sample_inputs_fn
=
sample_inputs_equalize_image_tensor
,
sample_inputs_fn
=
sample_inputs_equalize_image_tensor
,
reference_fn
=
pil_reference_wrapper
(
F
.
equalize_image_pil
),
reference_fn
=
pil_reference_wrapper
(
F
.
_
equalize_image_pil
),
float32_vs_uint8
=
True
,
float32_vs_uint8
=
True
,
reference_inputs_fn
=
reference_inputs_equalize_image_tensor
,
reference_inputs_fn
=
reference_inputs_equalize_image_tensor
,
),
),
...
@@ -1043,10 +1043,10 @@ def sample_inputs_invert_video():
...
@@ -1043,10 +1043,10 @@ def sample_inputs_invert_video():
KERNEL_INFOS
.
extend
(
KERNEL_INFOS
.
extend
(
[
[
KernelInfo
(
KernelInfo
(
F
.
invert_image
_tensor
,
F
.
invert_image
,
kernel_name
=
"invert_image_tensor"
,
kernel_name
=
"invert_image_tensor"
,
sample_inputs_fn
=
sample_inputs_invert_image_tensor
,
sample_inputs_fn
=
sample_inputs_invert_image_tensor
,
reference_fn
=
pil_reference_wrapper
(
F
.
invert_image_pil
),
reference_fn
=
pil_reference_wrapper
(
F
.
_
invert_image_pil
),
reference_inputs_fn
=
reference_inputs_invert_image_tensor
,
reference_inputs_fn
=
reference_inputs_invert_image_tensor
,
float32_vs_uint8
=
True
,
float32_vs_uint8
=
True
,
),
),
...
@@ -1082,10 +1082,10 @@ def sample_inputs_posterize_video():
...
@@ -1082,10 +1082,10 @@ def sample_inputs_posterize_video():
KERNEL_INFOS
.
extend
(
KERNEL_INFOS
.
extend
(
[
[
KernelInfo
(
KernelInfo
(
F
.
posterize_image
_tensor
,
F
.
posterize_image
,
kernel_name
=
"posterize_image_tensor"
,
kernel_name
=
"posterize_image_tensor"
,
sample_inputs_fn
=
sample_inputs_posterize_image_tensor
,
sample_inputs_fn
=
sample_inputs_posterize_image_tensor
,
reference_fn
=
pil_reference_wrapper
(
F
.
posterize_image_pil
),
reference_fn
=
pil_reference_wrapper
(
F
.
_
posterize_image_pil
),
reference_inputs_fn
=
reference_inputs_posterize_image_tensor
,
reference_inputs_fn
=
reference_inputs_posterize_image_tensor
,
float32_vs_uint8
=
True
,
float32_vs_uint8
=
True
,
closeness_kwargs
=
float32_vs_uint8_pixel_difference
(),
closeness_kwargs
=
float32_vs_uint8_pixel_difference
(),
...
@@ -1127,10 +1127,10 @@ def sample_inputs_solarize_video():
...
@@ -1127,10 +1127,10 @@ def sample_inputs_solarize_video():
KERNEL_INFOS
.
extend
(
KERNEL_INFOS
.
extend
(
[
[
KernelInfo
(
KernelInfo
(
F
.
solarize_image
_tensor
,
F
.
solarize_image
,
kernel_name
=
"solarize_image_tensor"
,
kernel_name
=
"solarize_image_tensor"
,
sample_inputs_fn
=
sample_inputs_solarize_image_tensor
,
sample_inputs_fn
=
sample_inputs_solarize_image_tensor
,
reference_fn
=
pil_reference_wrapper
(
F
.
solarize_image_pil
),
reference_fn
=
pil_reference_wrapper
(
F
.
_
solarize_image_pil
),
reference_inputs_fn
=
reference_inputs_solarize_image_tensor
,
reference_inputs_fn
=
reference_inputs_solarize_image_tensor
,
float32_vs_uint8
=
uint8_to_float32_threshold_adapter
,
float32_vs_uint8
=
uint8_to_float32_threshold_adapter
,
closeness_kwargs
=
float32_vs_uint8_pixel_difference
(),
closeness_kwargs
=
float32_vs_uint8_pixel_difference
(),
...
@@ -1161,10 +1161,10 @@ def sample_inputs_autocontrast_video():
...
@@ -1161,10 +1161,10 @@ def sample_inputs_autocontrast_video():
KERNEL_INFOS
.
extend
(
KERNEL_INFOS
.
extend
(
[
[
KernelInfo
(
KernelInfo
(
F
.
autocontrast_image
_tensor
,
F
.
autocontrast_image
,
kernel_name
=
"autocontrast_image_tensor"
,
kernel_name
=
"autocontrast_image_tensor"
,
sample_inputs_fn
=
sample_inputs_autocontrast_image_tensor
,
sample_inputs_fn
=
sample_inputs_autocontrast_image_tensor
,
reference_fn
=
pil_reference_wrapper
(
F
.
autocontrast_image_pil
),
reference_fn
=
pil_reference_wrapper
(
F
.
_
autocontrast_image_pil
),
reference_inputs_fn
=
reference_inputs_autocontrast_image_tensor
,
reference_inputs_fn
=
reference_inputs_autocontrast_image_tensor
,
float32_vs_uint8
=
True
,
float32_vs_uint8
=
True
,
closeness_kwargs
=
{
closeness_kwargs
=
{
...
@@ -1206,10 +1206,10 @@ def sample_inputs_adjust_sharpness_video():
...
@@ -1206,10 +1206,10 @@ def sample_inputs_adjust_sharpness_video():
KERNEL_INFOS
.
extend
(
KERNEL_INFOS
.
extend
(
[
[
KernelInfo
(
KernelInfo
(
F
.
adjust_sharpness_image
_tensor
,
F
.
adjust_sharpness_image
,
kernel_name
=
"adjust_sharpness_image_tensor"
,
kernel_name
=
"adjust_sharpness_image_tensor"
,
sample_inputs_fn
=
sample_inputs_adjust_sharpness_image_tensor
,
sample_inputs_fn
=
sample_inputs_adjust_sharpness_image_tensor
,
reference_fn
=
pil_reference_wrapper
(
F
.
adjust_sharpness_image_pil
),
reference_fn
=
pil_reference_wrapper
(
F
.
_
adjust_sharpness_image_pil
),
reference_inputs_fn
=
reference_inputs_adjust_sharpness_image_tensor
,
reference_inputs_fn
=
reference_inputs_adjust_sharpness_image_tensor
,
float32_vs_uint8
=
True
,
float32_vs_uint8
=
True
,
closeness_kwargs
=
float32_vs_uint8_pixel_difference
(
2
),
closeness_kwargs
=
float32_vs_uint8_pixel_difference
(
2
),
...
@@ -1241,7 +1241,7 @@ def sample_inputs_erase_video():
...
@@ -1241,7 +1241,7 @@ def sample_inputs_erase_video():
KERNEL_INFOS
.
extend
(
KERNEL_INFOS
.
extend
(
[
[
KernelInfo
(
KernelInfo
(
F
.
erase_image
_tensor
,
F
.
erase_image
,
kernel_name
=
"erase_image_tensor"
,
kernel_name
=
"erase_image_tensor"
,
sample_inputs_fn
=
sample_inputs_erase_image_tensor
,
sample_inputs_fn
=
sample_inputs_erase_image_tensor
,
),
),
...
@@ -1276,10 +1276,10 @@ def sample_inputs_adjust_contrast_video():
...
@@ -1276,10 +1276,10 @@ def sample_inputs_adjust_contrast_video():
KERNEL_INFOS
.
extend
(
KERNEL_INFOS
.
extend
(
[
[
KernelInfo
(
KernelInfo
(
F
.
adjust_contrast_image
_tensor
,
F
.
adjust_contrast_image
,
kernel_name
=
"adjust_contrast_image_tensor"
,
kernel_name
=
"adjust_contrast_image_tensor"
,
sample_inputs_fn
=
sample_inputs_adjust_contrast_image_tensor
,
sample_inputs_fn
=
sample_inputs_adjust_contrast_image_tensor
,
reference_fn
=
pil_reference_wrapper
(
F
.
adjust_contrast_image_pil
),
reference_fn
=
pil_reference_wrapper
(
F
.
_
adjust_contrast_image_pil
),
reference_inputs_fn
=
reference_inputs_adjust_contrast_image_tensor
,
reference_inputs_fn
=
reference_inputs_adjust_contrast_image_tensor
,
float32_vs_uint8
=
True
,
float32_vs_uint8
=
True
,
closeness_kwargs
=
{
closeness_kwargs
=
{
...
@@ -1329,10 +1329,10 @@ def sample_inputs_adjust_gamma_video():
...
@@ -1329,10 +1329,10 @@ def sample_inputs_adjust_gamma_video():
KERNEL_INFOS
.
extend
(
KERNEL_INFOS
.
extend
(
[
[
KernelInfo
(
KernelInfo
(
F
.
adjust_gamma_image
_tensor
,
F
.
adjust_gamma_image
,
kernel_name
=
"adjust_gamma_image_tensor"
,
kernel_name
=
"adjust_gamma_image_tensor"
,
sample_inputs_fn
=
sample_inputs_adjust_gamma_image_tensor
,
sample_inputs_fn
=
sample_inputs_adjust_gamma_image_tensor
,
reference_fn
=
pil_reference_wrapper
(
F
.
adjust_gamma_image_pil
),
reference_fn
=
pil_reference_wrapper
(
F
.
_
adjust_gamma_image_pil
),
reference_inputs_fn
=
reference_inputs_adjust_gamma_image_tensor
,
reference_inputs_fn
=
reference_inputs_adjust_gamma_image_tensor
,
float32_vs_uint8
=
True
,
float32_vs_uint8
=
True
,
closeness_kwargs
=
{
closeness_kwargs
=
{
...
@@ -1372,10 +1372,10 @@ def sample_inputs_adjust_hue_video():
...
@@ -1372,10 +1372,10 @@ def sample_inputs_adjust_hue_video():
KERNEL_INFOS
.
extend
(
KERNEL_INFOS
.
extend
(
[
[
KernelInfo
(
KernelInfo
(
F
.
adjust_hue_image
_tensor
,
F
.
adjust_hue_image
,
kernel_name
=
"adjust_hue_image_tensor"
,
kernel_name
=
"adjust_hue_image_tensor"
,
sample_inputs_fn
=
sample_inputs_adjust_hue_image_tensor
,
sample_inputs_fn
=
sample_inputs_adjust_hue_image_tensor
,
reference_fn
=
pil_reference_wrapper
(
F
.
adjust_hue_image_pil
),
reference_fn
=
pil_reference_wrapper
(
F
.
_
adjust_hue_image_pil
),
reference_inputs_fn
=
reference_inputs_adjust_hue_image_tensor
,
reference_inputs_fn
=
reference_inputs_adjust_hue_image_tensor
,
float32_vs_uint8
=
True
,
float32_vs_uint8
=
True
,
closeness_kwargs
=
{
closeness_kwargs
=
{
...
@@ -1414,10 +1414,10 @@ def sample_inputs_adjust_saturation_video():
...
@@ -1414,10 +1414,10 @@ def sample_inputs_adjust_saturation_video():
KERNEL_INFOS
.
extend
(
KERNEL_INFOS
.
extend
(
[
[
KernelInfo
(
KernelInfo
(
F
.
adjust_saturation_image
_tensor
,
F
.
adjust_saturation_image
,
kernel_name
=
"adjust_saturation_image_tensor"
,
kernel_name
=
"adjust_saturation_image_tensor"
,
sample_inputs_fn
=
sample_inputs_adjust_saturation_image_tensor
,
sample_inputs_fn
=
sample_inputs_adjust_saturation_image_tensor
,
reference_fn
=
pil_reference_wrapper
(
F
.
adjust_saturation_image_pil
),
reference_fn
=
pil_reference_wrapper
(
F
.
_
adjust_saturation_image_pil
),
reference_inputs_fn
=
reference_inputs_adjust_saturation_image_tensor
,
reference_inputs_fn
=
reference_inputs_adjust_saturation_image_tensor
,
float32_vs_uint8
=
True
,
float32_vs_uint8
=
True
,
closeness_kwargs
=
{
closeness_kwargs
=
{
...
@@ -1517,8 +1517,7 @@ def multi_crop_pil_reference_wrapper(pil_kernel):
...
@@ -1517,8 +1517,7 @@ def multi_crop_pil_reference_wrapper(pil_kernel):
def
wrapper
(
input_tensor
,
*
other_args
,
**
kwargs
):
def
wrapper
(
input_tensor
,
*
other_args
,
**
kwargs
):
output
=
pil_reference_wrapper
(
pil_kernel
)(
input_tensor
,
*
other_args
,
**
kwargs
)
output
=
pil_reference_wrapper
(
pil_kernel
)(
input_tensor
,
*
other_args
,
**
kwargs
)
return
type
(
output
)(
return
type
(
output
)(
F
.
to_dtype_image_tensor
(
F
.
to_image_tensor
(
output_pil
),
dtype
=
input_tensor
.
dtype
,
scale
=
True
)
F
.
to_dtype_image
(
F
.
to_image
(
output_pil
),
dtype
=
input_tensor
.
dtype
,
scale
=
True
)
for
output_pil
in
output
for
output_pil
in
output
)
)
return
wrapper
return
wrapper
...
@@ -1532,9 +1531,9 @@ _common_five_ten_crop_marks = [
...
@@ -1532,9 +1531,9 @@ _common_five_ten_crop_marks = [
KERNEL_INFOS
.
extend
(
KERNEL_INFOS
.
extend
(
[
[
KernelInfo
(
KernelInfo
(
F
.
five_crop_image
_tensor
,
F
.
five_crop_image
,
sample_inputs_fn
=
sample_inputs_five_crop_image_tensor
,
sample_inputs_fn
=
sample_inputs_five_crop_image_tensor
,
reference_fn
=
multi_crop_pil_reference_wrapper
(
F
.
five_crop_image_pil
),
reference_fn
=
multi_crop_pil_reference_wrapper
(
F
.
_
five_crop_image_pil
),
reference_inputs_fn
=
reference_inputs_five_crop_image_tensor
,
reference_inputs_fn
=
reference_inputs_five_crop_image_tensor
,
test_marks
=
_common_five_ten_crop_marks
,
test_marks
=
_common_five_ten_crop_marks
,
),
),
...
@@ -1544,9 +1543,9 @@ KERNEL_INFOS.extend(
...
@@ -1544,9 +1543,9 @@ KERNEL_INFOS.extend(
test_marks
=
_common_five_ten_crop_marks
,
test_marks
=
_common_five_ten_crop_marks
,
),
),
KernelInfo
(
KernelInfo
(
F
.
ten_crop_image
_tensor
,
F
.
ten_crop_image
,
sample_inputs_fn
=
sample_inputs_ten_crop_image_tensor
,
sample_inputs_fn
=
sample_inputs_ten_crop_image_tensor
,
reference_fn
=
multi_crop_pil_reference_wrapper
(
F
.
ten_crop_image_pil
),
reference_fn
=
multi_crop_pil_reference_wrapper
(
F
.
_
ten_crop_image_pil
),
reference_inputs_fn
=
reference_inputs_ten_crop_image_tensor
,
reference_inputs_fn
=
reference_inputs_ten_crop_image_tensor
,
test_marks
=
_common_five_ten_crop_marks
,
test_marks
=
_common_five_ten_crop_marks
,
),
),
...
@@ -1600,7 +1599,7 @@ def sample_inputs_normalize_video():
...
@@ -1600,7 +1599,7 @@ def sample_inputs_normalize_video():
KERNEL_INFOS
.
extend
(
KERNEL_INFOS
.
extend
(
[
[
KernelInfo
(
KernelInfo
(
F
.
normalize_image
_tensor
,
F
.
normalize_image
,
kernel_name
=
"normalize_image_tensor"
,
kernel_name
=
"normalize_image_tensor"
,
sample_inputs_fn
=
sample_inputs_normalize_image_tensor
,
sample_inputs_fn
=
sample_inputs_normalize_image_tensor
,
reference_fn
=
reference_normalize_image_tensor
,
reference_fn
=
reference_normalize_image_tensor
,
...
...
torchvision/prototype/transforms/_augment.py
View file @
ca012d39
...
@@ -112,7 +112,7 @@ class SimpleCopyPaste(Transform):
...
@@ -112,7 +112,7 @@ class SimpleCopyPaste(Transform):
if
isinstance
(
obj
,
datapoints
.
Image
)
or
is_simple_tensor
(
obj
):
if
isinstance
(
obj
,
datapoints
.
Image
)
or
is_simple_tensor
(
obj
):
images
.
append
(
obj
)
images
.
append
(
obj
)
elif
isinstance
(
obj
,
PIL
.
Image
.
Image
):
elif
isinstance
(
obj
,
PIL
.
Image
.
Image
):
images
.
append
(
F
.
to_image
_tensor
(
obj
))
images
.
append
(
F
.
to_image
(
obj
))
elif
isinstance
(
obj
,
datapoints
.
BoundingBoxes
):
elif
isinstance
(
obj
,
datapoints
.
BoundingBoxes
):
bboxes
.
append
(
obj
)
bboxes
.
append
(
obj
)
elif
isinstance
(
obj
,
datapoints
.
Mask
):
elif
isinstance
(
obj
,
datapoints
.
Mask
):
...
@@ -144,7 +144,7 @@ class SimpleCopyPaste(Transform):
...
@@ -144,7 +144,7 @@ class SimpleCopyPaste(Transform):
flat_sample
[
i
]
=
datapoints
.
wrap
(
output_images
[
c0
],
like
=
obj
)
flat_sample
[
i
]
=
datapoints
.
wrap
(
output_images
[
c0
],
like
=
obj
)
c0
+=
1
c0
+=
1
elif
isinstance
(
obj
,
PIL
.
Image
.
Image
):
elif
isinstance
(
obj
,
PIL
.
Image
.
Image
):
flat_sample
[
i
]
=
F
.
to_image
_pil
(
output_images
[
c0
])
flat_sample
[
i
]
=
F
.
to_
pil_
image
(
output_images
[
c0
])
c0
+=
1
c0
+=
1
elif
is_simple_tensor
(
obj
):
elif
is_simple_tensor
(
obj
):
flat_sample
[
i
]
=
output_images
[
c0
]
flat_sample
[
i
]
=
output_images
[
c0
]
...
...
torchvision/transforms/v2/__init__.py
View file @
ca012d39
...
@@ -52,7 +52,7 @@ from ._misc import (
...
@@ -52,7 +52,7 @@ from ._misc import (
ToDtype
,
ToDtype
,
)
)
from
._temporal
import
UniformTemporalSubsample
from
._temporal
import
UniformTemporalSubsample
from
._type_conversion
import
PILToTensor
,
ToImage
PIL
,
ToImageTensor
,
ToPILImage
from
._type_conversion
import
PILToTensor
,
ToImage
,
ToPILImage
from
._deprecated
import
ToTensor
# usort: skip
from
._deprecated
import
ToTensor
# usort: skip
...
...
torchvision/transforms/v2/_auto_augment.py
View file @
ca012d39
...
@@ -622,6 +622,6 @@ class AugMix(_AutoAugmentBase):
...
@@ -622,6 +622,6 @@ class AugMix(_AutoAugmentBase):
if
isinstance
(
orig_image_or_video
,
(
datapoints
.
Image
,
datapoints
.
Video
)):
if
isinstance
(
orig_image_or_video
,
(
datapoints
.
Image
,
datapoints
.
Video
)):
mix
=
datapoints
.
wrap
(
mix
,
like
=
orig_image_or_video
)
mix
=
datapoints
.
wrap
(
mix
,
like
=
orig_image_or_video
)
elif
isinstance
(
orig_image_or_video
,
PIL
.
Image
.
Image
):
elif
isinstance
(
orig_image_or_video
,
PIL
.
Image
.
Image
):
mix
=
F
.
to_image
_pil
(
mix
)
mix
=
F
.
to_
pil_
image
(
mix
)
return
self
.
_unflatten_and_insert_image_or_video
(
flat_inputs_with_spec
,
mix
)
return
self
.
_unflatten_and_insert_image_or_video
(
flat_inputs_with_spec
,
mix
)
torchvision/transforms/v2/_type_conversion.py
View file @
ca012d39
...
@@ -26,7 +26,7 @@ class PILToTensor(Transform):
...
@@ -26,7 +26,7 @@ class PILToTensor(Transform):
return
F
.
pil_to_tensor
(
inpt
)
return
F
.
pil_to_tensor
(
inpt
)
class
ToImage
Tensor
(
Transform
):
class
ToImage
(
Transform
):
"""[BETA] Convert a tensor, ndarray, or PIL Image to :class:`~torchvision.datapoints.Image`
"""[BETA] Convert a tensor, ndarray, or PIL Image to :class:`~torchvision.datapoints.Image`
; this does not scale values.
; this does not scale values.
...
@@ -40,10 +40,10 @@ class ToImageTensor(Transform):
...
@@ -40,10 +40,10 @@ class ToImageTensor(Transform):
def
_transform
(
def
_transform
(
self
,
inpt
:
Union
[
torch
.
Tensor
,
PIL
.
Image
.
Image
,
np
.
ndarray
],
params
:
Dict
[
str
,
Any
]
self
,
inpt
:
Union
[
torch
.
Tensor
,
PIL
.
Image
.
Image
,
np
.
ndarray
],
params
:
Dict
[
str
,
Any
]
)
->
datapoints
.
Image
:
)
->
datapoints
.
Image
:
return
F
.
to_image
_tensor
(
inpt
)
return
F
.
to_image
(
inpt
)
class
ToImage
PIL
(
Transform
):
class
To
PIL
Image
(
Transform
):
"""[BETA] Convert a tensor or an ndarray to PIL Image - this does not scale values.
"""[BETA] Convert a tensor or an ndarray to PIL Image - this does not scale values.
.. v2betastatus:: ToImagePIL transform
.. v2betastatus:: ToImagePIL transform
...
@@ -74,9 +74,4 @@ class ToImagePIL(Transform):
...
@@ -74,9 +74,4 @@ class ToImagePIL(Transform):
def
_transform
(
def
_transform
(
self
,
inpt
:
Union
[
torch
.
Tensor
,
PIL
.
Image
.
Image
,
np
.
ndarray
],
params
:
Dict
[
str
,
Any
]
self
,
inpt
:
Union
[
torch
.
Tensor
,
PIL
.
Image
.
Image
,
np
.
ndarray
],
params
:
Dict
[
str
,
Any
]
)
->
PIL
.
Image
.
Image
:
)
->
PIL
.
Image
.
Image
:
return
F
.
to_image_pil
(
inpt
,
mode
=
self
.
mode
)
return
F
.
to_pil_image
(
inpt
,
mode
=
self
.
mode
)
# We changed the name to align them with the new naming scheme. Still, `ToPILImage` is
# prevalent and well understood. Thus, we just alias it without deprecating the old name.
ToPILImage
=
ToImagePIL
torchvision/transforms/v2/functional/__init__.py
View file @
ca012d39
...
@@ -5,173 +5,173 @@ from ._utils import is_simple_tensor, register_kernel # usort: skip
...
@@ -5,173 +5,173 @@ from ._utils import is_simple_tensor, register_kernel # usort: skip
from
._meta
import
(
from
._meta
import
(
clamp_bounding_boxes
,
clamp_bounding_boxes
,
convert_format_bounding_boxes
,
convert_format_bounding_boxes
,
get_dimensions_image
_tensor
,
get_dimensions_image
,
get_dimensions_image_pil
,
_
get_dimensions_image_pil
,
get_dimensions_video
,
get_dimensions_video
,
get_dimensions
,
get_dimensions
,
get_num_frames_video
,
get_num_frames_video
,
get_num_frames
,
get_num_frames
,
get_image_num_channels
,
get_image_num_channels
,
get_num_channels_image
_tensor
,
get_num_channels_image
,
get_num_channels_image_pil
,
_
get_num_channels_image_pil
,
get_num_channels_video
,
get_num_channels_video
,
get_num_channels
,
get_num_channels
,
get_size_bounding_boxes
,
get_size_bounding_boxes
,
get_size_image
_tensor
,
get_size_image
,
get_size_image_pil
,
_
get_size_image_pil
,
get_size_mask
,
get_size_mask
,
get_size_video
,
get_size_video
,
get_size
,
get_size
,
)
# usort: skip
)
# usort: skip
from
._augment
import
erase
,
erase_image_pil
,
erase
_image_tensor
,
erase_video
from
._augment
import
_
erase_image_pil
,
erase
,
erase_image
,
erase_video
from
._color
import
(
from
._color
import
(
_adjust_brightness_image_pil
,
_adjust_contrast_image_pil
,
_adjust_gamma_image_pil
,
_adjust_hue_image_pil
,
_adjust_saturation_image_pil
,
_adjust_sharpness_image_pil
,
_autocontrast_image_pil
,
_equalize_image_pil
,
_invert_image_pil
,
_permute_channels_image_pil
,
_posterize_image_pil
,
_rgb_to_grayscale_image_pil
,
_solarize_image_pil
,
adjust_brightness
,
adjust_brightness
,
adjust_brightness_image_pil
,
adjust_brightness_image
,
adjust_brightness_image_tensor
,
adjust_brightness_video
,
adjust_brightness_video
,
adjust_contrast
,
adjust_contrast
,
adjust_contrast_image_pil
,
adjust_contrast_image
,
adjust_contrast_image_tensor
,
adjust_contrast_video
,
adjust_contrast_video
,
adjust_gamma
,
adjust_gamma
,
adjust_gamma_image_pil
,
adjust_gamma_image
,
adjust_gamma_image_tensor
,
adjust_gamma_video
,
adjust_gamma_video
,
adjust_hue
,
adjust_hue
,
adjust_hue_image_pil
,
adjust_hue_image
,
adjust_hue_image_tensor
,
adjust_hue_video
,
adjust_hue_video
,
adjust_saturation
,
adjust_saturation
,
adjust_saturation_image_pil
,
adjust_saturation_image
,
adjust_saturation_image_tensor
,
adjust_saturation_video
,
adjust_saturation_video
,
adjust_sharpness
,
adjust_sharpness
,
adjust_sharpness_image_pil
,
adjust_sharpness_image
,
adjust_sharpness_image_tensor
,
adjust_sharpness_video
,
adjust_sharpness_video
,
autocontrast
,
autocontrast
,
autocontrast_image_pil
,
autocontrast_image
,
autocontrast_image_tensor
,
autocontrast_video
,
autocontrast_video
,
equalize
,
equalize
,
equalize_image_pil
,
equalize_image
,
equalize_image_tensor
,
equalize_video
,
equalize_video
,
invert
,
invert
,
invert_image_pil
,
invert_image
,
invert_image_tensor
,
invert_video
,
invert_video
,
permute_channels
,
permute_channels
,
permute_channels_image_pil
,
permute_channels_image
,
permute_channels_image_tensor
,
permute_channels_video
,
permute_channels_video
,
posterize
,
posterize
,
posterize_image_pil
,
posterize_image
,
posterize_image_tensor
,
posterize_video
,
posterize_video
,
rgb_to_grayscale
,
rgb_to_grayscale
,
rgb_to_grayscale_image_pil
,
rgb_to_grayscale_image
,
rgb_to_grayscale_image_tensor
,
solarize
,
solarize
,
solarize_image_pil
,
solarize_image
,
solarize_image_tensor
,
solarize_video
,
solarize_video
,
to_grayscale
,
to_grayscale
,
)
)
from
._geometry
import
(
from
._geometry
import
(
_affine_image_pil
,
_center_crop_image_pil
,
_crop_image_pil
,
_elastic_image_pil
,
_five_crop_image_pil
,
_horizontal_flip_image_pil
,
_pad_image_pil
,
_perspective_image_pil
,
_resize_image_pil
,
_resized_crop_image_pil
,
_rotate_image_pil
,
_ten_crop_image_pil
,
_vertical_flip_image_pil
,
affine
,
affine
,
affine_bounding_boxes
,
affine_bounding_boxes
,
affine_image_pil
,
affine_image
,
affine_image_tensor
,
affine_mask
,
affine_mask
,
affine_video
,
affine_video
,
center_crop
,
center_crop
,
center_crop_bounding_boxes
,
center_crop_bounding_boxes
,
center_crop_image_pil
,
center_crop_image
,
center_crop_image_tensor
,
center_crop_mask
,
center_crop_mask
,
center_crop_video
,
center_crop_video
,
crop
,
crop
,
crop_bounding_boxes
,
crop_bounding_boxes
,
crop_image_pil
,
crop_image
,
crop_image_tensor
,
crop_mask
,
crop_mask
,
crop_video
,
crop_video
,
elastic
,
elastic
,
elastic_bounding_boxes
,
elastic_bounding_boxes
,
elastic_image_pil
,
elastic_image
,
elastic_image_tensor
,
elastic_mask
,
elastic_mask
,
elastic_transform
,
elastic_transform
,
elastic_video
,
elastic_video
,
five_crop
,
five_crop
,
five_crop_image_pil
,
five_crop_image
,
five_crop_image_tensor
,
five_crop_video
,
five_crop_video
,
hflip
,
# TODO: Consider moving all pure alias definitions at the bottom of the file
hflip
,
# TODO: Consider moving all pure alias definitions at the bottom of the file
horizontal_flip
,
horizontal_flip
,
horizontal_flip_bounding_boxes
,
horizontal_flip_bounding_boxes
,
horizontal_flip_image_pil
,
horizontal_flip_image
,
horizontal_flip_image_tensor
,
horizontal_flip_mask
,
horizontal_flip_mask
,
horizontal_flip_video
,
horizontal_flip_video
,
pad
,
pad
,
pad_bounding_boxes
,
pad_bounding_boxes
,
pad_image_pil
,
pad_image
,
pad_image_tensor
,
pad_mask
,
pad_mask
,
pad_video
,
pad_video
,
perspective
,
perspective
,
perspective_bounding_boxes
,
perspective_bounding_boxes
,
perspective_image_pil
,
perspective_image
,
perspective_image_tensor
,
perspective_mask
,
perspective_mask
,
perspective_video
,
perspective_video
,
resize
,
resize
,
resize_bounding_boxes
,
resize_bounding_boxes
,
resize_image_pil
,
resize_image
,
resize_image_tensor
,
resize_mask
,
resize_mask
,
resize_video
,
resize_video
,
resized_crop
,
resized_crop
,
resized_crop_bounding_boxes
,
resized_crop_bounding_boxes
,
resized_crop_image_pil
,
resized_crop_image
,
resized_crop_image_tensor
,
resized_crop_mask
,
resized_crop_mask
,
resized_crop_video
,
resized_crop_video
,
rotate
,
rotate
,
rotate_bounding_boxes
,
rotate_bounding_boxes
,
rotate_image_pil
,
rotate_image
,
rotate_image_tensor
,
rotate_mask
,
rotate_mask
,
rotate_video
,
rotate_video
,
ten_crop
,
ten_crop
,
ten_crop_image_pil
,
ten_crop_image
,
ten_crop_image_tensor
,
ten_crop_video
,
ten_crop_video
,
vertical_flip
,
vertical_flip
,
vertical_flip_bounding_boxes
,
vertical_flip_bounding_boxes
,
vertical_flip_image_pil
,
vertical_flip_image
,
vertical_flip_image_tensor
,
vertical_flip_mask
,
vertical_flip_mask
,
vertical_flip_video
,
vertical_flip_video
,
vflip
,
vflip
,
)
)
from
._misc
import
(
from
._misc
import
(
_gaussian_blur_image_pil
,
convert_image_dtype
,
convert_image_dtype
,
gaussian_blur
,
gaussian_blur
,
gaussian_blur_image_pil
,
gaussian_blur_image
,
gaussian_blur_image_tensor
,
gaussian_blur_video
,
gaussian_blur_video
,
normalize
,
normalize
,
normalize_image
_tensor
,
normalize_image
,
normalize_video
,
normalize_video
,
to_dtype
,
to_dtype
,
to_dtype_image
_tensor
,
to_dtype_image
,
to_dtype_video
,
to_dtype_video
,
)
)
from
._temporal
import
uniform_temporal_subsample
,
uniform_temporal_subsample_video
from
._temporal
import
uniform_temporal_subsample
,
uniform_temporal_subsample_video
from
._type_conversion
import
pil_to_tensor
,
to_image
_pil
,
to_image_tensor
,
to_pil_image
from
._type_conversion
import
pil_to_tensor
,
to_image
,
to_pil_image
from
._deprecated
import
get_image_size
,
to_tensor
# usort: skip
from
._deprecated
import
get_image_size
,
to_tensor
# usort: skip
torchvision/transforms/v2/functional/_augment.py
View file @
ca012d39
...
@@ -18,7 +18,7 @@ def erase(
...
@@ -18,7 +18,7 @@ def erase(
inplace
:
bool
=
False
,
inplace
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
if
torch
.
jit
.
is_scripting
():
return
erase_image
_tensor
(
inpt
,
i
=
i
,
j
=
j
,
h
=
h
,
w
=
w
,
v
=
v
,
inplace
=
inplace
)
return
erase_image
(
inpt
,
i
=
i
,
j
=
j
,
h
=
h
,
w
=
w
,
v
=
v
,
inplace
=
inplace
)
_log_api_usage_once
(
erase
)
_log_api_usage_once
(
erase
)
...
@@ -28,7 +28,7 @@ def erase(
...
@@ -28,7 +28,7 @@ def erase(
@
_register_kernel_internal
(
erase
,
torch
.
Tensor
)
@
_register_kernel_internal
(
erase
,
torch
.
Tensor
)
@
_register_kernel_internal
(
erase
,
datapoints
.
Image
)
@
_register_kernel_internal
(
erase
,
datapoints
.
Image
)
def
erase_image
_tensor
(
def
erase_image
(
image
:
torch
.
Tensor
,
i
:
int
,
j
:
int
,
h
:
int
,
w
:
int
,
v
:
torch
.
Tensor
,
inplace
:
bool
=
False
image
:
torch
.
Tensor
,
i
:
int
,
j
:
int
,
h
:
int
,
w
:
int
,
v
:
torch
.
Tensor
,
inplace
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
not
inplace
:
if
not
inplace
:
...
@@ -39,11 +39,11 @@ def erase_image_tensor(
...
@@ -39,11 +39,11 @@ def erase_image_tensor(
@
_register_kernel_internal
(
erase
,
PIL
.
Image
.
Image
)
@
_register_kernel_internal
(
erase
,
PIL
.
Image
.
Image
)
def
erase_image_pil
(
def
_
erase_image_pil
(
image
:
PIL
.
Image
.
Image
,
i
:
int
,
j
:
int
,
h
:
int
,
w
:
int
,
v
:
torch
.
Tensor
,
inplace
:
bool
=
False
image
:
PIL
.
Image
.
Image
,
i
:
int
,
j
:
int
,
h
:
int
,
w
:
int
,
v
:
torch
.
Tensor
,
inplace
:
bool
=
False
)
->
PIL
.
Image
.
Image
:
)
->
PIL
.
Image
.
Image
:
t_img
=
pil_to_tensor
(
image
)
t_img
=
pil_to_tensor
(
image
)
output
=
erase_image
_tensor
(
t_img
,
i
=
i
,
j
=
j
,
h
=
h
,
w
=
w
,
v
=
v
,
inplace
=
inplace
)
output
=
erase_image
(
t_img
,
i
=
i
,
j
=
j
,
h
=
h
,
w
=
w
,
v
=
v
,
inplace
=
inplace
)
return
to_pil_image
(
output
,
mode
=
image
.
mode
)
return
to_pil_image
(
output
,
mode
=
image
.
mode
)
...
@@ -51,4 +51,4 @@ def erase_image_pil(
...
@@ -51,4 +51,4 @@ def erase_image_pil(
def
erase_video
(
def
erase_video
(
video
:
torch
.
Tensor
,
i
:
int
,
j
:
int
,
h
:
int
,
w
:
int
,
v
:
torch
.
Tensor
,
inplace
:
bool
=
False
video
:
torch
.
Tensor
,
i
:
int
,
j
:
int
,
h
:
int
,
w
:
int
,
v
:
torch
.
Tensor
,
inplace
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
erase_image
_tensor
(
video
,
i
=
i
,
j
=
j
,
h
=
h
,
w
=
w
,
v
=
v
,
inplace
=
inplace
)
return
erase_image
(
video
,
i
=
i
,
j
=
j
,
h
=
h
,
w
=
w
,
v
=
v
,
inplace
=
inplace
)
torchvision/transforms/v2/functional/_color.py
View file @
ca012d39
This diff is collapsed.
Click to expand it.
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment