OpenDAS / vision, commit db310636 (unverified)

move passthrough for unknown types from dispatchers to transforms (#7804)

Authored Aug 10, 2023 by Philip Meier; committed via GitHub on Aug 10, 2023. Parent: 87681314.

Showing 18 changed files with 150 additions and 731 deletions.
test/test_prototype_transforms.py                     +0   -63
test/test_transforms_v2.py                            +2   -360
test/test_transforms_v2_consistency.py                +0   -62
test/test_transforms_v2_refactored.py                 +16  -40
torchvision/prototype/transforms/_geometry.py         +3   -2
torchvision/transforms/v2/_augment.py                 +10  -2
torchvision/transforms/v2/_color.py                   +21  -50
torchvision/transforms/v2/_geometry.py                +57  -24
torchvision/transforms/v2/_misc.py                    +5   -8
torchvision/transforms/v2/_temporal.py                +1   -1
torchvision/transforms/v2/_transform.py               +6   -0
torchvision/transforms/v2/functional/_augment.py      +1   -2
torchvision/transforms/v2/functional/_color.py        +2   -15
torchvision/transforms/v2/functional/_geometry.py     +9   -17
torchvision/transforms/v2/functional/_meta.py         +1   -4
torchvision/transforms/v2/functional/_misc.py         +1   -5
torchvision/transforms/v2/functional/_temporal.py     +1   -5
torchvision/transforms/v2/functional/_utils.py        +14  -71
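At a glance, the commit inverts where unknown input types are tolerated. Previously the functional dispatchers (F.equalize, F.erase, ...) carried explicitly registered no-op kernels for types like BoundingBoxes and Mask; now the dispatchers raise for unregistered types, and only the Transform classes pass such inputs through, via a new Transform._call_kernel hook. A minimal sketch of the resulting behavior, assuming the post-commit API described by the tests below:

import torch
from torchvision import datapoints
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F

boxes = datapoints.BoundingBoxes(
    torch.tensor([[0, 0, 5, 5]]), format="XYXY", canvas_size=(10, 10)
)

# After this commit the dispatcher is strict: no kernel is registered for
# BoundingBoxes, so this should raise instead of returning the input.
try:
    F.equalize(boxes)
except TypeError as e:
    print(e)  # "... supports inputs of type ..."

# The transform, by contrast, resolves kernels with allow_passthrough=True
# and hands unsupported inputs back unchanged.
out = v2.RandomEqualize(p=1.0)(boxes)
print(out is boxes)  # expected: True (pass-through happens in the transform)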
test/test_prototype_transforms.py

-import itertools
-import re

 import PIL.Image

@@ -19,7 +17,6 @@ from prototype_common_utils import make_label
 from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
 from torchvision.prototype import datapoints, transforms
-from torchvision.transforms.v2._utils import _convert_fill_arg
 from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_image_pil
 from torchvision.transforms.v2.utils import check_type, is_simple_tensor

@@ -187,66 +184,6 @@ class TestFixedSizeCrop:
         assert params["needs_pad"]
         assert any(pad > 0 for pad in params["padding"])

-    @pytest.mark.parametrize("needs", list(itertools.product((False, True), repeat=2)))
-    def test__transform(self, mocker, needs):
-        fill_sentinel = 12
-        padding_mode_sentinel = mocker.MagicMock()
-
-        transform = transforms.FixedSizeCrop((-1, -1), fill=fill_sentinel, padding_mode=padding_mode_sentinel)
-        transform._transformed_types = (mocker.MagicMock,)
-        mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
-
-        needs_crop, needs_pad = needs
-        top_sentinel = mocker.MagicMock()
-        left_sentinel = mocker.MagicMock()
-        height_sentinel = mocker.MagicMock()
-        width_sentinel = mocker.MagicMock()
-        is_valid = mocker.MagicMock() if needs_crop else None
-        padding_sentinel = mocker.MagicMock()
-        mocker.patch(
-            "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
-            return_value=dict(
-                needs_crop=needs_crop,
-                top=top_sentinel,
-                left=left_sentinel,
-                height=height_sentinel,
-                width=width_sentinel,
-                is_valid=is_valid,
-                padding=padding_sentinel,
-                needs_pad=needs_pad,
-            ),
-        )
-
-        inpt_sentinel = mocker.MagicMock()
-
-        mock_crop = mocker.patch("torchvision.prototype.transforms._geometry.F.crop")
-        mock_pad = mocker.patch("torchvision.prototype.transforms._geometry.F.pad")
-        transform(inpt_sentinel)
-
-        if needs_crop:
-            mock_crop.assert_called_once_with(
-                inpt_sentinel,
-                top=top_sentinel,
-                left=left_sentinel,
-                height=height_sentinel,
-                width=width_sentinel,
-            )
-        else:
-            mock_crop.assert_not_called()
-
-        if needs_pad:
-            # If we cropped before, the input to F.pad is no longer inpt_sentinel. Thus, we can't use
-            # `MagicMock.assert_called_once_with` and have to perform the checks manually
-            mock_pad.assert_called_once()
-            args, kwargs = mock_pad.call_args
-            if not needs_crop:
-                assert args[0] is inpt_sentinel
-            assert args[1] is padding_sentinel
-            fill_sentinel = _convert_fill_arg(fill_sentinel)
-            assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel)
-        else:
-            mock_pad.assert_not_called()
-
     def test__transform_culling(self, mocker):
         batch_size = 10
         canvas_size = (10, 10)
test/test_transforms_v2.py

@@ -27,7 +27,7 @@ from common_utils import (
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision import datapoints
 from torchvision.ops.boxes import box_iou
-from torchvision.transforms.functional import InterpolationMode, to_pil_image
+from torchvision.transforms.functional import to_pil_image
 from torchvision.transforms.v2 import functional as F
 from torchvision.transforms.v2.utils import check_type, is_simple_tensor, query_chw

@@ -419,46 +419,6 @@ class TestPad:
         with pytest.raises(ValueError, match="Padding mode should be either"):
             transforms.Pad(12, padding_mode="abc")

-    @pytest.mark.parametrize("padding", [1, (1, 2), [1, 2, 3, 4]])
-    @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)])
-    @pytest.mark.parametrize("padding_mode", ["constant", "edge"])
-    def test__transform(self, padding, fill, padding_mode, mocker):
-        transform = transforms.Pad(padding, fill=fill, padding_mode=padding_mode)
-
-        fn = mocker.patch("torchvision.transforms.v2.functional.pad")
-        inpt = mocker.MagicMock(spec=datapoints.Image)
-        _ = transform(inpt)
-
-        fill = transforms._utils._convert_fill_arg(fill)
-        if isinstance(padding, tuple):
-            padding = list(padding)
-        fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
-
-    @pytest.mark.parametrize("fill", [12, {datapoints.Image: 12, datapoints.Mask: 34}])
-    def test__transform_image_mask(self, fill, mocker):
-        transform = transforms.Pad(1, fill=fill, padding_mode="constant")
-
-        fn = mocker.patch("torchvision.transforms.v2.functional.pad")
-        image = datapoints.Image(torch.rand(3, 32, 32))
-        mask = datapoints.Mask(torch.randint(0, 5, size=(32, 32)))
-        inpt = [image, mask]
-        _ = transform(inpt)
-
-        if isinstance(fill, int):
-            fill = transforms._utils._convert_fill_arg(fill)
-            calls = [
-                mocker.call(image, padding=1, fill=fill, padding_mode="constant"),
-                mocker.call(mask, padding=1, fill=fill, padding_mode="constant"),
-            ]
-        else:
-            fill_img = transforms._utils._convert_fill_arg(fill[type(image)])
-            fill_mask = transforms._utils._convert_fill_arg(fill[type(mask)])
-            calls = [
-                mocker.call(image, padding=1, fill=fill_img, padding_mode="constant"),
-                mocker.call(mask, padding=1, fill=fill_mask, padding_mode="constant"),
-            ]
-        fn.assert_has_calls(calls)
-
 class TestRandomZoomOut:
     def test_assertions(self):

@@ -487,56 +447,6 @@ class TestRandomZoomOut:
         assert 0 <= params["padding"][2] <= (side_range[1] - 1) * w
         assert 0 <= params["padding"][3] <= (side_range[1] - 1) * h

-    @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)])
-    @pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]])
-    def test__transform(self, fill, side_range, mocker):
-        inpt = make_image((24, 32))
-
-        transform = transforms.RandomZoomOut(fill=fill, side_range=side_range, p=1)
-
-        fn = mocker.patch("torchvision.transforms.v2.functional.pad")
-        # vfdev-5, Feature Request: let's store params as Transform attribute
-        # This could be also helpful for users
-        # Otherwise, we can mock transform._get_params
-        torch.manual_seed(12)
-        _ = transform(inpt)
-        torch.manual_seed(12)
-        torch.rand(1)  # random apply changes random state
-        params = transform._get_params([inpt])
-
-        fill = transforms._utils._convert_fill_arg(fill)
-        fn.assert_called_once_with(inpt, **params, fill=fill)
-
-    @pytest.mark.parametrize("fill", [12, {datapoints.Image: 12, datapoints.Mask: 34}])
-    def test__transform_image_mask(self, fill, mocker):
-        transform = transforms.RandomZoomOut(fill=fill, p=1.0)
-
-        fn = mocker.patch("torchvision.transforms.v2.functional.pad")
-        image = datapoints.Image(torch.rand(3, 32, 32))
-        mask = datapoints.Mask(torch.randint(0, 5, size=(32, 32)))
-        inpt = [image, mask]
-
-        torch.manual_seed(12)
-        _ = transform(inpt)
-        torch.manual_seed(12)
-        torch.rand(1)  # random apply changes random state
-        params = transform._get_params(inpt)
-
-        if isinstance(fill, int):
-            fill = transforms._utils._convert_fill_arg(fill)
-            calls = [
-                mocker.call(image, **params, fill=fill),
-                mocker.call(mask, **params, fill=fill),
-            ]
-        else:
-            fill_img = transforms._utils._convert_fill_arg(fill[type(image)])
-            fill_mask = transforms._utils._convert_fill_arg(fill[type(mask)])
-            calls = [
-                mocker.call(image, **params, fill=fill_img),
-                mocker.call(mask, **params, fill=fill_mask),
-            ]
-        fn.assert_has_calls(calls)
-
 class TestRandomCrop:
     def test_assertions(self):

@@ -599,51 +509,6 @@ class TestRandomCrop:
         assert params["needs_pad"] is any(padding)
         assert params["padding"] == padding

-    @pytest.mark.parametrize("padding", [None, 1, [2, 3], [1, 2, 3, 4]])
-    @pytest.mark.parametrize("pad_if_needed", [False, True])
-    @pytest.mark.parametrize("fill", [False, True])
-    @pytest.mark.parametrize("padding_mode", ["constant", "edge"])
-    def test__transform(self, padding, pad_if_needed, fill, padding_mode, mocker):
-        output_size = [10, 12]
-        transform = transforms.RandomCrop(
-            output_size, padding=padding, pad_if_needed=pad_if_needed, fill=fill, padding_mode=padding_mode
-        )
-
-        h, w = size = (32, 32)
-        inpt = make_image(size)
-
-        if isinstance(padding, int):
-            new_size = (h + padding, w + padding)
-        elif isinstance(padding, list):
-            new_size = (h + sum(padding[0::2]), w + sum(padding[1::2]))
-        else:
-            new_size = size
-        expected = make_image(new_size)
-        _ = mocker.patch("torchvision.transforms.v2.functional.pad", return_value=expected)
-        fn_crop = mocker.patch("torchvision.transforms.v2.functional.crop")
-
-        # vfdev-5, Feature Request: let's store params as Transform attribute
-        # This could be also helpful for users
-        # Otherwise, we can mock transform._get_params
-        torch.manual_seed(12)
-        _ = transform(inpt)
-        torch.manual_seed(12)
-        params = transform._get_params([inpt])
-
-        if padding is None and not pad_if_needed:
-            fn_crop.assert_called_once_with(
-                inpt, top=params["top"], left=params["left"], height=output_size[0], width=output_size[1]
-            )
-        elif not pad_if_needed:
-            fn_crop.assert_called_once_with(
-                expected, top=params["top"], left=params["left"], height=output_size[0], width=output_size[1]
-            )
-        elif padding is None:
-            # vfdev-5: I do not know how to mock and test this case
-            pass
-        else:
-            # vfdev-5: I do not know how to mock and test this case
-            pass
-
 class TestGaussianBlur:
     def test_assertions(self):

@@ -675,62 +540,6 @@ class TestGaussianBlur:
         assert sigma[0] <= params["sigma"][0] <= sigma[1]
         assert sigma[0] <= params["sigma"][1] <= sigma[1]

-    @pytest.mark.parametrize("kernel_size", [3, [3, 5], (5, 3)])
-    @pytest.mark.parametrize("sigma", [2.0, [2.0, 3.0]])
-    def test__transform(self, kernel_size, sigma, mocker):
-        transform = transforms.GaussianBlur(kernel_size=kernel_size, sigma=sigma)
-
-        if isinstance(kernel_size, (tuple, list)):
-            assert transform.kernel_size == kernel_size
-        else:
-            kernel_size = (kernel_size, kernel_size)
-            assert transform.kernel_size == kernel_size
-
-        if isinstance(sigma, (tuple, list)):
-            assert transform.sigma == sigma
-        else:
-            assert transform.sigma == [sigma, sigma]
-
-        fn = mocker.patch("torchvision.transforms.v2.functional.gaussian_blur")
-        inpt = mocker.MagicMock(spec=datapoints.Image)
-        inpt.num_channels = 3
-        inpt.canvas_size = (24, 32)
-
-        # vfdev-5, Feature Request: let's store params as Transform attribute
-        # This could be also helpful for users
-        # Otherwise, we can mock transform._get_params
-        torch.manual_seed(12)
-        _ = transform(inpt)
-        torch.manual_seed(12)
-        params = transform._get_params([inpt])
-
-        fn.assert_called_once_with(inpt, kernel_size, **params)
-
-class TestRandomColorOp:
-    @pytest.mark.parametrize("p", [0.0, 1.0])
-    @pytest.mark.parametrize(
-        "transform_cls, func_op_name, kwargs",
-        [
-            (transforms.RandomEqualize, "equalize", {}),
-            (transforms.RandomInvert, "invert", {}),
-            (transforms.RandomAutocontrast, "autocontrast", {}),
-            (transforms.RandomPosterize, "posterize", {"bits": 4}),
-            (transforms.RandomSolarize, "solarize", {"threshold": 0.5}),
-            (transforms.RandomAdjustSharpness, "adjust_sharpness", {"sharpness_factor": 0.5}),
-        ],
-    )
-    def test__transform(self, p, transform_cls, func_op_name, kwargs, mocker):
-        transform = transform_cls(p=p, **kwargs)
-
-        fn = mocker.patch(f"torchvision.transforms.v2.functional.{func_op_name}")
-        inpt = mocker.MagicMock(spec=datapoints.Image)
-        _ = transform(inpt)
-
-        if p > 0.0:
-            fn.assert_called_once_with(inpt, **kwargs)
-        else:
-            assert fn.call_count == 0
-
 class TestRandomPerspective:
     def test_assertions(self):

@@ -751,28 +560,6 @@ class TestRandomPerspective:
         assert "coefficients" in params
         assert len(params["coefficients"]) == 8

-    @pytest.mark.parametrize("distortion_scale", [0.1, 0.7])
-    def test__transform(self, distortion_scale, mocker):
-        interpolation = InterpolationMode.BILINEAR
-        fill = 12
-        transform = transforms.RandomPerspective(distortion_scale, fill=fill, interpolation=interpolation)
-
-        fn = mocker.patch("torchvision.transforms.v2.functional.perspective")
-        inpt = make_image((24, 32))
-
-        # vfdev-5, Feature Request: let's store params as Transform attribute
-        # This could be also helpful for users
-        # Otherwise, we can mock transform._get_params
-        torch.manual_seed(12)
-        _ = transform(inpt)
-        torch.manual_seed(12)
-        torch.rand(1)  # random apply changes random state
-        params = transform._get_params([inpt])
-
-        fill = transforms._utils._convert_fill_arg(fill)
-        fn.assert_called_once_with(inpt, None, None, **params, fill=fill, interpolation=interpolation)
-
 class TestElasticTransform:
     def test_assertions(self):

@@ -813,35 +600,6 @@ class TestElasticTransform:
         assert (-alpha / w <= displacement[0, ..., 0]).all() and (displacement[0, ..., 0] <= alpha / w).all()
         assert (-alpha / h <= displacement[0, ..., 1]).all() and (displacement[0, ..., 1] <= alpha / h).all()

-    @pytest.mark.parametrize("alpha", [5.0, [5.0, 10.0]])
-    @pytest.mark.parametrize("sigma", [2.0, [2.0, 5.0]])
-    def test__transform(self, alpha, sigma, mocker):
-        interpolation = InterpolationMode.BILINEAR
-        fill = 12
-        transform = transforms.ElasticTransform(alpha, sigma=sigma, fill=fill, interpolation=interpolation)
-
-        if isinstance(alpha, float):
-            assert transform.alpha == [alpha, alpha]
-        else:
-            assert transform.alpha == alpha
-
-        if isinstance(sigma, float):
-            assert transform.sigma == [sigma, sigma]
-        else:
-            assert transform.sigma == sigma
-
-        fn = mocker.patch("torchvision.transforms.v2.functional.elastic")
-        inpt = mocker.MagicMock(spec=datapoints.Image)
-        inpt.num_channels = 3
-        inpt.canvas_size = (24, 32)
-
-        # Let's mock transform._get_params to control the output:
-        transform._get_params = mocker.MagicMock()
-        _ = transform(inpt)
-        params = transform._get_params([inpt])
-        fill = transforms._utils._convert_fill_arg(fill)
-        fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
-
 class TestRandomErasing:
     def test_assertions(self):

@@ -889,40 +647,6 @@ class TestRandomErasing:
         assert 0 <= i <= height - h
         assert 0 <= j <= width - w

-    @pytest.mark.parametrize("p", [0, 1])
-    def test__transform(self, mocker, p):
-        transform = transforms.RandomErasing(p=p)
-        transform._transformed_types = (mocker.MagicMock,)
-
-        i_sentinel = mocker.MagicMock()
-        j_sentinel = mocker.MagicMock()
-        h_sentinel = mocker.MagicMock()
-        w_sentinel = mocker.MagicMock()
-        v_sentinel = mocker.MagicMock()
-        mocker.patch(
-            "torchvision.transforms.v2._augment.RandomErasing._get_params",
-            return_value=dict(i=i_sentinel, j=j_sentinel, h=h_sentinel, w=w_sentinel, v=v_sentinel),
-        )
-
-        inpt_sentinel = mocker.MagicMock()
-
-        mock = mocker.patch("torchvision.transforms.v2._augment.F.erase")
-        output = transform(inpt_sentinel)
-
-        if p:
-            mock.assert_called_once_with(
-                inpt_sentinel,
-                i=i_sentinel,
-                j=j_sentinel,
-                h=h_sentinel,
-                w=w_sentinel,
-                v=v_sentinel,
-                inplace=transform.inplace,
-            )
-        else:
-            mock.assert_not_called()
-            assert output is inpt_sentinel
-
 class TestTransform:
     @pytest.mark.parametrize(

@@ -1111,23 +835,12 @@ class TestRandomIoUCrop:
         sample = [image, bboxes, masks]

-        fn = mocker.patch("torchvision.transforms.v2.functional.crop", side_effect=lambda x, **params: x)
         is_within_crop_area = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool)

         params = dict(top=1, left=2, height=12, width=12, is_within_crop_area=is_within_crop_area)
         transform._get_params = mocker.MagicMock(return_value=params)
         output = transform(sample)

-        assert fn.call_count == 3
-
-        expected_calls = [
-            mocker.call(image, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
-            mocker.call(bboxes, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
-            mocker.call(masks, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
-        ]
-        fn.assert_has_calls(expected_calls)
-
         # check number of bboxes vs number of labels:
         output_bboxes = output[1]
         assert isinstance(output_bboxes, datapoints.BoundingBoxes)

@@ -1164,29 +877,6 @@ class TestScaleJitter:
         assert int(canvas_size[0] * r_min) <= height <= int(canvas_size[0] * r_max)
         assert int(canvas_size[1] * r_min) <= width <= int(canvas_size[1] * r_max)

-    def test__transform(self, mocker):
-        interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
-        antialias_sentinel = mocker.MagicMock()
-
-        transform = transforms.ScaleJitter(
-            target_size=(16, 12), interpolation=interpolation_sentinel, antialias=antialias_sentinel
-        )
-        transform._transformed_types = (mocker.MagicMock,)
-
-        size_sentinel = mocker.MagicMock()
-        mocker.patch(
-            "torchvision.transforms.v2._geometry.ScaleJitter._get_params", return_value=dict(size=size_sentinel)
-        )
-
-        inpt_sentinel = mocker.MagicMock()
-
-        mock = mocker.patch("torchvision.transforms.v2._geometry.F.resize")
-        transform(inpt_sentinel)
-
-        mock.assert_called_once_with(
-            inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel
-        )
-
 class TestRandomShortestSize:
     @pytest.mark.parametrize("min_size,max_size", [([5, 9], 20), ([5, 9], None)])

@@ -1211,30 +901,6 @@ class TestRandomShortestSize:
         else:
             assert shorter in min_size

-    def test__transform(self, mocker):
-        interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
-        antialias_sentinel = mocker.MagicMock()
-
-        transform = transforms.RandomShortestSize(
-            min_size=[3, 5, 7], max_size=12, interpolation=interpolation_sentinel, antialias=antialias_sentinel
-        )
-        transform._transformed_types = (mocker.MagicMock,)
-
-        size_sentinel = mocker.MagicMock()
-        mocker.patch(
-            "torchvision.transforms.v2._geometry.RandomShortestSize._get_params",
-            return_value=dict(size=size_sentinel),
-        )
-
-        inpt_sentinel = mocker.MagicMock()
-
-        mock = mocker.patch("torchvision.transforms.v2._geometry.F.resize")
-        transform(inpt_sentinel)
-
-        mock.assert_called_once_with(
-            inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel
-        )
-
 class TestLinearTransformation:
     def test_assertions(self):

@@ -1260,7 +926,7 @@ class TestLinearTransformation:
         transform = transforms.LinearTransformation(m, v)

         if isinstance(inpt, PIL.Image.Image):
-            with pytest.raises(TypeError, match="LinearTransformation does not work on PIL Images"):
+            with pytest.raises(TypeError, match="does not support PIL images"):
                 transform(inpt)
         else:
             output = transform(inpt)

@@ -1284,30 +950,6 @@ class TestRandomResize:
         assert min_size <= size < max_size

-    def test__transform(self, mocker):
-        interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
-        antialias_sentinel = mocker.MagicMock()
-
-        transform = transforms.RandomResize(
-            min_size=-1, max_size=-1, interpolation=interpolation_sentinel, antialias=antialias_sentinel
-        )
-        transform._transformed_types = (mocker.MagicMock,)
-
-        size_sentinel = mocker.MagicMock()
-        mocker.patch(
-            "torchvision.transforms.v2._geometry.RandomResize._get_params",
-            return_value=dict(size=size_sentinel),
-        )
-
-        inpt_sentinel = mocker.MagicMock()
-
-        mock_resize = mocker.patch("torchvision.transforms.v2._geometry.F.resize")
-        transform(inpt_sentinel)
-
-        mock_resize.assert_called_with(
-            inpt_sentinel, size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel
-        )
-
 class TestUniformTemporalSubsample:
     @pytest.mark.parametrize(
test/test_transforms_v2_consistency.py

@@ -1259,68 +1259,6 @@ class TestRefSegTransforms:
     def test_common(self, t_ref, t, data_kwargs):
         self.check(t, t_ref, data_kwargs)

-    def check_resize(self, mocker, t_ref, t):
-        mock = mocker.patch("torchvision.transforms.v2._geometry.F.resize")
-        mock_ref = mocker.patch("torchvision.transforms.functional.resize")
-
-        for dp, dp_ref in self.make_datapoints():
-            mock.reset_mock()
-            mock_ref.reset_mock()
-
-            self.set_seed()
-            t(dp)
-            assert mock.call_count == 2
-            assert all(
-                actual is expected
-                for actual, expected in zip([call_args[0][0] for call_args in mock.call_args_list], dp)
-            )
-
-            self.set_seed()
-            t_ref(*dp_ref)
-            assert mock_ref.call_count == 2
-            assert all(
-                actual is expected
-                for actual, expected in zip([call_args[0][0] for call_args in mock_ref.call_args_list], dp_ref)
-            )
-
-            for args_kwargs, args_kwargs_ref in zip(mock.call_args_list, mock_ref.call_args_list):
-                assert args_kwargs[0][1] == [args_kwargs_ref[0][1]]
-
-    def test_random_resize_train(self, mocker):
-        base_size = 520
-        min_size = base_size // 2
-        max_size = base_size * 2
-
-        randint = torch.randint
-
-        def patched_randint(a, b, *other_args, **kwargs):
-            if kwargs or len(other_args) > 1 or other_args[0] != ():
-                return randint(a, b, *other_args, **kwargs)
-            return random.randint(a, b)
-
-        # We are patching torch.randint -> random.randint here, because we can't patch the modules that are not imported
-        # normally
-        t = v2_transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True)
-        mocker.patch(
-            "torchvision.transforms.v2._geometry.torch.randint",
-            new=patched_randint,
-        )
-
-        t_ref = seg_transforms.RandomResize(min_size=min_size, max_size=max_size)
-
-        self.check_resize(mocker, t_ref, t)
-
-    def test_random_resize_eval(self, mocker):
-        torch.manual_seed(0)
-
-        base_size = 520
-
-        t = v2_transforms.Resize(size=base_size, antialias=True)
-        t_ref = seg_transforms.RandomResize(min_size=base_size, max_size=base_size)
-
-        self.check_resize(mocker, t_ref, t)
-
     @pytest.mark.parametrize(
         ("legacy_dispatcher", "name_only_params"),
test/test_transforms_v2_refactored.py

@@ -39,7 +39,7 @@ from torchvision import datapoints
 from torchvision.transforms._functional_tensor import _max_value as get_max_value
 from torchvision.transforms.functional import pil_modes_mapping
 from torchvision.transforms.v2 import functional as F
-from torchvision.transforms.v2.functional._utils import _get_kernel, _KERNEL_REGISTRY, _noop, _register_kernel_internal
+from torchvision.transforms.v2.functional._utils import _get_kernel, _register_kernel_internal

 @pytest.fixture(autouse=True)

@@ -376,35 +376,6 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_size
     return torch.stack([transform(b) for b in bounding_boxes.reshape(-1, 4).unbind()]).reshape(bounding_boxes.shape)

-@pytest.mark.parametrize(
-    ("dispatcher", "registered_input_types"),
-    [(dispatcher, set(registry.keys())) for dispatcher, registry in _KERNEL_REGISTRY.items()],
-)
-def test_exhaustive_kernel_registration(dispatcher, registered_input_types):
-    missing = {
-        torch.Tensor,
-        PIL.Image.Image,
-        datapoints.Image,
-        datapoints.BoundingBoxes,
-        datapoints.Mask,
-        datapoints.Video,
-    } - registered_input_types
-    if missing:
-        names = sorted(str(t) for t in missing)
-        raise AssertionError(
-            "\n".join(
-                [
-                    f"The dispatcher '{dispatcher.__name__}' has no kernel registered for",
-                    "",
-                    *[f"- {name}" for name in names],
-                    "",
-                    f"If available, register the kernels with @_register_kernel_internal({dispatcher.__name__}, ...).",
-                    f"If not, register explicit no-ops with @_register_explicit_noop({', '.join(names)})",
-                ]
-            )
-        )
-
 class TestResize:
     INPUT_SIZE = (17, 11)
     OUTPUT_SIZES = [17, [17], (17,), [12, 13], (12, 13)]

@@ -2128,9 +2099,20 @@ class TestRegisterKernel:
         with pytest.raises(ValueError, match="Kernels can only be registered for subclasses"):
             F.register_kernel(F.resize, object)

-        with pytest.raises(ValueError, match="already has a kernel registered for type"):
+        with pytest.raises(ValueError, match="cannot be registered for the builtin datapoint classes"):
             F.register_kernel(F.resize, datapoints.Image)(F.resize_image_tensor)

+        class CustomDatapoint(datapoints.Datapoint):
+            pass
+
+        def resize_custom_datapoint():
+            pass
+
+        F.register_kernel(F.resize, CustomDatapoint)(resize_custom_datapoint)
+
+        with pytest.raises(ValueError, match="already has a kernel registered for type"):
+            F.register_kernel(F.resize, CustomDatapoint)(resize_custom_datapoint)
+
 class TestGetKernel:
     # We are using F.resize as dispatcher and the kernels below as proxy. Any other dispatcher / kernels combination

@@ -2152,13 +2134,7 @@ class TestGetKernel:
             pass

         for input_type in [str, int, object, MyTensor, MyPILImage]:
-            with pytest.raises(
-                TypeError,
-                match=(
-                    "supports inputs of type torch.Tensor, PIL.Image.Image, "
-                    "and subclasses of torchvision.datapoints.Datapoint"
-                ),
-            ):
+            with pytest.raises(TypeError, match="supports inputs of type"):
                 _get_kernel(F.resize, input_type)

     def test_exact_match(self):

@@ -2211,8 +2187,8 @@ class TestGetKernel:
         class MyDatapoint(datapoints.Datapoint):
             pass

-        # Note that this will be an error in the future
-        assert _get_kernel(F.resize, MyDatapoint) is _noop
+        with pytest.raises(TypeError, match="supports inputs of type"):
+            _get_kernel(F.resize, MyDatapoint)

         def resize_my_datapoint():
             pass
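The reworked TestRegisterKernel case encodes the new contract: kernels cannot be (re)registered for the builtin datapoint classes, while a custom Datapoint subclass gets exactly one registration. A short usage sketch along the same lines (the kernel body here is illustrative, not from the diff):

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

class MyDatapoint(datapoints.Datapoint):
    pass

@F.register_kernel(F.resize, MyDatapoint)
def resize_my_datapoint(inpt, size, **kwargs):
    # Illustrative kernel: unwrap and delegate to the plain-tensor implementation.
    return F.resize_image_tensor(inpt.as_subclass(torch.Tensor), size, **kwargs)

# Registering a second kernel for the same type now raises:
#   ValueError: ... already has a kernel registered for type ...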
torchvision/prototype/transforms/_geometry.py

@@ -101,7 +101,8 @@ class FixedSizeCrop(Transform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if params["needs_crop"]:
-            inpt = F.crop(
+            inpt = self._call_kernel(
+                F.crop,
                 inpt,
                 top=params["top"],
                 left=params["left"],

@@ -120,6 +121,6 @@ class FixedSizeCrop(Transform):
         if params["needs_pad"]:
             fill = _get_fill(self._fill, type(inpt))
-            inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)
+            inpt = self._call_kernel(F.pad, inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)

         return inpt
torchvision/transforms/v2/_augment.py

 import math
 import numbers
 import warnings
-from typing import Any, Dict, List, Tuple
+from typing import Any, Callable, Dict, List, Tuple

 import PIL.Image
 import torch

@@ -91,6 +91,14 @@ class RandomErasing(_RandomApplyTransform):
         self._log_ratio = torch.log(torch.tensor(self.ratio))

+    def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
+        if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
+            warnings.warn(
+                f"{type(self).__name__}() is currently passing through inputs of type "
+                f"datapoints.{type(inpt).__name__}. This will likely change in the future."
+            )
+        return super()._call_kernel(dispatcher, inpt, *args, **kwargs)
+
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         img_c, img_h, img_w = query_chw(flat_inputs)

@@ -131,7 +139,7 @@ class RandomErasing(_RandomApplyTransform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if params["v"] is not None:
-            inpt = F.erase(inpt, **params, inplace=self.inplace)
+            inpt = self._call_kernel(F.erase, inpt, **params, inplace=self.inplace)

         return inpt
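With the `_call_kernel` override above, RandomErasing no longer fails on annotation types it cannot erase; it warns and passes them through. Roughly (a sketch, assuming the warning text shown in the diff):

import torch
from torchvision import datapoints
from torchvision.transforms import v2

mask = datapoints.Mask(torch.zeros(1, 16, 16, dtype=torch.uint8))
out = v2.RandomErasing(p=1.0)(mask)
# expected: UserWarning "RandomErasing() is currently passing through inputs of
# type datapoints.Mask. This will likely change in the future.", with the mask
# returned unchanged, since F.erase has no kernel registered for Mask.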
torchvision/transforms/v2/_color.py

 import collections.abc
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

 import PIL.Image
 import torch

-from torchvision import datapoints, transforms as _transforms
+from torchvision import transforms as _transforms
 from torchvision.transforms.v2 import functional as F, Transform

 from ._transform import _RandomApplyTransform
-from .utils import is_simple_tensor, query_chw
+from .utils import query_chw

 class Grayscale(Transform):

@@ -24,19 +23,12 @@ class Grayscale(Transform):
     _v1_transform_cls = _transforms.Grayscale

-    _transformed_types = (
-        datapoints.Image,
-        PIL.Image.Image,
-        is_simple_tensor,
-        datapoints.Video,
-    )
-
     def __init__(self, num_output_channels: int = 1):
         super().__init__()
         self.num_output_channels = num_output_channels

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
+        return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=self.num_output_channels)

 class RandomGrayscale(_RandomApplyTransform):

@@ -55,13 +47,6 @@ class RandomGrayscale(_RandomApplyTransform):
     _v1_transform_cls = _transforms.RandomGrayscale

-    _transformed_types = (
-        datapoints.Image,
-        PIL.Image.Image,
-        is_simple_tensor,
-        datapoints.Video,
-    )
-
     def __init__(self, p: float = 0.1) -> None:
         super().__init__(p=p)

@@ -70,7 +55,7 @@ class RandomGrayscale(_RandomApplyTransform):
         return dict(num_input_channels=num_input_channels)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
+        return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=params["num_input_channels"])

 class ColorJitter(Transform):

@@ -167,13 +152,13 @@ class ColorJitter(Transform):
         hue_factor = params["hue_factor"]
         for fn_id in params["fn_idx"]:
             if fn_id == 0 and brightness_factor is not None:
-                output = F.adjust_brightness(output, brightness_factor=brightness_factor)
+                output = self._call_kernel(F.adjust_brightness, output, brightness_factor=brightness_factor)
             elif fn_id == 1 and contrast_factor is not None:
-                output = F.adjust_contrast(output, contrast_factor=contrast_factor)
+                output = self._call_kernel(F.adjust_contrast, output, contrast_factor=contrast_factor)
             elif fn_id == 2 and saturation_factor is not None:
-                output = F.adjust_saturation(output, saturation_factor=saturation_factor)
+                output = self._call_kernel(F.adjust_saturation, output, saturation_factor=saturation_factor)
             elif fn_id == 3 and hue_factor is not None:
-                output = F.adjust_hue(output, hue_factor=hue_factor)
+                output = self._call_kernel(F.adjust_hue, output, hue_factor=hue_factor)
         return output

@@ -183,19 +168,12 @@ class RandomChannelPermutation(Transform):
     .. v2betastatus:: RandomChannelPermutation transform
     """

-    _transformed_types = (
-        datapoints.Image,
-        PIL.Image.Image,
-        is_simple_tensor,
-        datapoints.Video,
-    )
-
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         num_channels, *_ = query_chw(flat_inputs)
         return dict(permutation=torch.randperm(num_channels))

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.permute_channels(inpt, params["permutation"])
+        return self._call_kernel(F.permute_channels, inpt, params["permutation"])

 class RandomPhotometricDistort(Transform):

@@ -224,13 +202,6 @@ class RandomPhotometricDistort(Transform):
             Default is 0.5.
     """

-    _transformed_types = (
-        datapoints.Image,
-        PIL.Image.Image,
-        is_simple_tensor,
-        datapoints.Video,
-    )
-
     def __init__(
         self,
         brightness: Tuple[float, float] = (0.875, 1.125),

@@ -263,17 +234,17 @@ class RandomPhotometricDistort(Transform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if params["brightness_factor"] is not None:
-            inpt = F.adjust_brightness(inpt, brightness_factor=params["brightness_factor"])
+            inpt = self._call_kernel(F.adjust_brightness, inpt, brightness_factor=params["brightness_factor"])
         if params["contrast_factor"] is not None and params["contrast_before"]:
-            inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"])
+            inpt = self._call_kernel(F.adjust_contrast, inpt, contrast_factor=params["contrast_factor"])
         if params["saturation_factor"] is not None:
-            inpt = F.adjust_saturation(inpt, saturation_factor=params["saturation_factor"])
+            inpt = self._call_kernel(F.adjust_saturation, inpt, saturation_factor=params["saturation_factor"])
         if params["hue_factor"] is not None:
-            inpt = F.adjust_hue(inpt, hue_factor=params["hue_factor"])
+            inpt = self._call_kernel(F.adjust_hue, inpt, hue_factor=params["hue_factor"])
         if params["contrast_factor"] is not None and not params["contrast_before"]:
-            inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"])
+            inpt = self._call_kernel(F.adjust_contrast, inpt, contrast_factor=params["contrast_factor"])
         if params["channel_permutation"] is not None:
-            inpt = F.permute_channels(inpt, permutation=params["channel_permutation"])
+            inpt = self._call_kernel(F.permute_channels, inpt, permutation=params["channel_permutation"])
         return inpt

@@ -293,7 +264,7 @@ class RandomEqualize(_RandomApplyTransform):
     _v1_transform_cls = _transforms.RandomEqualize

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.equalize(inpt)
+        return self._call_kernel(F.equalize, inpt)

 class RandomInvert(_RandomApplyTransform):

@@ -312,7 +283,7 @@ class RandomInvert(_RandomApplyTransform):
     _v1_transform_cls = _transforms.RandomInvert

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.invert(inpt)
+        return self._call_kernel(F.invert, inpt)

 class RandomPosterize(_RandomApplyTransform):

@@ -337,7 +308,7 @@ class RandomPosterize(_RandomApplyTransform):
         self.bits = bits

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.posterize(inpt, bits=self.bits)
+        return self._call_kernel(F.posterize, inpt, bits=self.bits)

 class RandomSolarize(_RandomApplyTransform):

@@ -362,7 +333,7 @@ class RandomSolarize(_RandomApplyTransform):
         self.threshold = threshold

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.solarize(inpt, threshold=self.threshold)
+        return self._call_kernel(F.solarize, inpt, threshold=self.threshold)

 class RandomAutocontrast(_RandomApplyTransform):

@@ -381,7 +352,7 @@ class RandomAutocontrast(_RandomApplyTransform):
     _v1_transform_cls = _transforms.RandomAutocontrast

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.autocontrast(inpt)
+        return self._call_kernel(F.autocontrast, inpt)

 class RandomAdjustSharpness(_RandomApplyTransform):

@@ -406,4 +377,4 @@ class RandomAdjustSharpness(_RandomApplyTransform):
         self.sharpness_factor = sharpness_factor

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.adjust_sharpness(inpt, sharpness_factor=self.sharpness_factor)
+        return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=self.sharpness_factor)
torchvision/transforms/v2/_geometry.py

 import math
 import numbers
 import warnings
-from typing import Any, cast, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Callable, cast, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union

 import PIL.Image
 import torch

@@ -44,7 +44,7 @@ class RandomHorizontalFlip(_RandomApplyTransform):
     _v1_transform_cls = _transforms.RandomHorizontalFlip

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.horizontal_flip(inpt)
+        return self._call_kernel(F.horizontal_flip, inpt)

 class RandomVerticalFlip(_RandomApplyTransform):

@@ -64,7 +64,7 @@ class RandomVerticalFlip(_RandomApplyTransform):
     _v1_transform_cls = _transforms.RandomVerticalFlip

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.vertical_flip(inpt)
+        return self._call_kernel(F.vertical_flip, inpt)

 class Resize(Transform):

@@ -152,7 +152,8 @@ class Resize(Transform):
         self.antialias = antialias

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.resize(
+        return self._call_kernel(
+            F.resize,
             inpt,
             self.size,
             interpolation=self.interpolation,

@@ -186,7 +187,7 @@ class CenterCrop(Transform):
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.center_crop(inpt, output_size=self.size)
+        return self._call_kernel(F.center_crop, inpt, output_size=self.size)

 class RandomResizedCrop(Transform):

@@ -307,8 +308,8 @@ class RandomResizedCrop(Transform):
         return dict(top=i, left=j, height=h, width=w)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.resized_crop(
-            inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias
+        return self._call_kernel(
+            F.resized_crop, inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias
         )

@@ -357,8 +358,16 @@ class FiveCrop(Transform):
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

+    def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
+        if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
+            warnings.warn(
+                f"{type(self).__name__}() is currently passing through inputs of type "
+                f"datapoints.{type(inpt).__name__}. This will likely change in the future."
+            )
+        return super()._call_kernel(dispatcher, inpt, *args, **kwargs)
+
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.five_crop(inpt, self.size)
+        return self._call_kernel(F.five_crop, inpt, self.size)

     def _check_inputs(self, flat_inputs: List[Any]) -> None:
         if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask):

@@ -396,12 +405,20 @@ class TenCrop(Transform):
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
         self.vertical_flip = vertical_flip

+    def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
+        if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
+            warnings.warn(
+                f"{type(self).__name__}() is currently passing through inputs of type "
+                f"datapoints.{type(inpt).__name__}. This will likely change in the future."
+            )
+        return super()._call_kernel(dispatcher, inpt, *args, **kwargs)
+
     def _check_inputs(self, flat_inputs: List[Any]) -> None:
         if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask):
             raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()")

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
+        return self._call_kernel(F.ten_crop, inpt, self.size, vertical_flip=self.vertical_flip)

 class Pad(Transform):

@@ -475,7 +492,7 @@ class Pad(Transform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
-        return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)  # type: ignore[arg-type]
+        return self._call_kernel(F.pad, inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)  # type: ignore[arg-type]

 class RandomZoomOut(_RandomApplyTransform):

@@ -545,7 +562,7 @@ class RandomZoomOut(_RandomApplyTransform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
-        return F.pad(inpt, **params, fill=fill)
+        return self._call_kernel(F.pad, inpt, **params, fill=fill)

 class RandomRotation(Transform):

@@ -611,7 +628,8 @@ class RandomRotation(Transform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
-        return F.rotate(
+        return self._call_kernel(
+            F.rotate,
             inpt,
             **params,
             interpolation=self.interpolation,

@@ -733,7 +751,8 @@ class RandomAffine(Transform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
-        return F.affine(
+        return self._call_kernel(
+            F.affine,
             inpt,
             **params,
             interpolation=self.interpolation,

@@ -889,10 +908,12 @@ class RandomCrop(Transform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if params["needs_pad"]:
             fill = _get_fill(self._fill, type(inpt))
-            inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)
+            inpt = self._call_kernel(F.pad, inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)

         if params["needs_crop"]:
-            inpt = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])
+            inpt = self._call_kernel(
+                F.crop, inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
+            )

         return inpt

@@ -973,7 +994,8 @@ class RandomPerspective(_RandomApplyTransform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
-        return F.perspective(
+        return self._call_kernel(
+            F.perspective,
             inpt,
             None,
             None,

@@ -1050,7 +1072,7 @@ class ElasticTransform(Transform):
         # if kernel size is even we have to make it odd
         if kx % 2 == 0:
             kx += 1
-        dx = F.gaussian_blur(dx, [kx, kx], list(self.sigma))
+        dx = self._call_kernel(F.gaussian_blur, dx, [kx, kx], list(self.sigma))
         dx = dx * self.alpha[0] / size[0]

         dy = torch.rand([1, 1] + size) * 2 - 1

@@ -1059,14 +1081,15 @@ class ElasticTransform(Transform):
         # if kernel size is even we have to make it odd
         if ky % 2 == 0:
             ky += 1
-        dy = F.gaussian_blur(dy, [ky, ky], list(self.sigma))
+        dy = self._call_kernel(F.gaussian_blur, dy, [ky, ky], list(self.sigma))
         dy = dy * self.alpha[1] / size[1]
         displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1])  # 1 x H x W x 2
         return dict(displacement=displacement)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
-        return F.elastic(
+        return self._call_kernel(
+            F.elastic,
             inpt,
             **params,
             fill=fill,

@@ -1164,7 +1187,9 @@ class RandomIoUCrop(Transform):
                 # check for any valid boxes with centers within the crop area
                 xyxy_bboxes = F.convert_format_bounding_boxes(
-                    bboxes.as_subclass(torch.Tensor), bboxes.format, datapoints.BoundingBoxFormat.XYXY
+                    bboxes.as_subclass(torch.Tensor),
+                    bboxes.format,
+                    datapoints.BoundingBoxFormat.XYXY,
                 )
                 cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])
                 cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3])

@@ -1188,7 +1213,9 @@ class RandomIoUCrop(Transform):
         if len(params) < 1:
             return inpt

-        output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])
+        output = self._call_kernel(
+            F.crop, inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
+        )

         if isinstance(output, datapoints.BoundingBoxes):
             # We "mark" the invalid boxes as degenreate, and they can be

@@ -1262,7 +1289,9 @@ class ScaleJitter(Transform):
         return dict(size=(new_height, new_width))

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias)
+        return self._call_kernel(
+            F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias
+        )

 class RandomShortestSize(Transform):

@@ -1330,7 +1359,9 @@ class RandomShortestSize(Transform):
         return dict(size=(new_height, new_width))

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias)
+        return self._call_kernel(
+            F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias
+        )

 class RandomResize(Transform):

@@ -1400,4 +1431,6 @@ class RandomResize(Transform):
         return dict(size=[size])

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.resize(inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias)
+        return self._call_kernel(
+            F.resize, inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias
+        )
torchvision/transforms/v2/_misc.py

@@ -106,7 +106,7 @@ class LinearTransformation(Transform):
     def _check_inputs(self, sample: Any) -> Any:
         if has_any(sample, PIL.Image.Image):
-            raise TypeError("LinearTransformation does not work on PIL Images")
+            raise TypeError(f"{type(self).__name__}() does not support PIL images.")

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         shape = inpt.shape

@@ -157,7 +157,6 @@ class Normalize(Transform):
     """

     _v1_transform_cls = _transforms.Normalize
-    _transformed_types = (datapoints.Image, is_simple_tensor, datapoints.Video)

     def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False):
         super().__init__()

@@ -170,7 +169,7 @@ class Normalize(Transform):
             raise TypeError(f"{type(self).__name__}() does not support PIL images.")

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
+        return self._call_kernel(F.normalize, inpt, mean=self.mean, std=self.std, inplace=self.inplace)

 class GaussianBlur(Transform):

@@ -217,7 +216,7 @@ class GaussianBlur(Transform):
         return dict(sigma=[sigma, sigma])

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.gaussian_blur(inpt, self.kernel_size, **params)
+        return self._call_kernel(F.gaussian_blur, inpt, self.kernel_size, **params)

 class ToDtype(Transform):

@@ -290,7 +289,7 @@ class ToDtype(Transform):
             )
             return inpt

-        return F.to_dtype(inpt, dtype=dtype, scale=self.scale)
+        return self._call_kernel(F.to_dtype, inpt, dtype=dtype, scale=self.scale)

 class ConvertImageDtype(Transform):

@@ -320,14 +319,12 @@ class ConvertImageDtype(Transform):
     _v1_transform_cls = _transforms.ConvertImageDtype

-    _transformed_types = (is_simple_tensor, datapoints.Image)
-
     def __init__(self, dtype: torch.dtype = torch.float32) -> None:
         super().__init__()
         self.dtype = dtype

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.to_dtype(inpt, dtype=self.dtype, scale=True)
+        return self._call_kernel(F.to_dtype, inpt, dtype=self.dtype, scale=True)

 class SanitizeBoundingBoxes(Transform):
torchvision/transforms/v2/_temporal.py

@@ -25,4 +25,4 @@ class UniformTemporalSubsample(Transform):
         self.num_samples = num_samples

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.uniform_temporal_subsample(inpt, self.num_samples)
+        return self._call_kernel(F.uniform_temporal_subsample, inpt, self.num_samples)
torchvision/transforms/v2/_transform.py

@@ -11,6 +11,8 @@ from torchvision import datapoints
 from torchvision.transforms.v2.utils import check_type, has_any, is_simple_tensor
 from torchvision.utils import _log_api_usage_once

+from .functional._utils import _get_kernel
+

 class Transform(nn.Module):

@@ -28,6 +30,10 @@ class Transform(nn.Module):
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict()

+    def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
+        kernel = _get_kernel(dispatcher, type(inpt), allow_passthrough=True)
+        return kernel(inpt, *args, **kwargs)
+
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         raise NotImplementedError
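`_call_kernel` is now the single seam through which every v2 transform reaches a kernel, and the `allow_passthrough=True` flag is what preserves the old tolerance for unknown types, but at the transform level only. A plausible simplification of the lookup (the real `_get_kernel` lives in functional/_utils.py, whose diff is not expanded on this page, and it also handles subclass registrations; the registry name and structure below are assumptions):

from typing import Any, Callable, Dict

# Stand-in for the real registry: dispatcher -> {input type: kernel}.
_KERNEL_REGISTRY: Dict[Callable, Dict[type, Callable]] = {}

def _get_kernel_sketch(dispatcher: Callable, input_type: type, allow_passthrough: bool = False) -> Callable:
    registry = _KERNEL_REGISTRY.get(dispatcher, {})
    for cls in input_type.__mro__:  # exact type first, then base classes
        if cls in registry:
            return registry[cls]
    if allow_passthrough:
        # No kernel known for this type: hand the input back untouched.
        return lambda inpt, *args, **kwargs: inpt
    raise TypeError(f"{getattr(dispatcher, '__name__', dispatcher)} supports inputs of type ... but got {input_type}")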
torchvision/transforms/v2/functional/_augment.py

@@ -5,10 +5,9 @@ from torchvision import datapoints
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 from torchvision.utils import _log_api_usage_once

-from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
+from ._utils import _get_kernel, _register_kernel_internal

-@_register_explicit_noop(datapoints.Mask, datapoints.BoundingBoxes, warn_passthrough=True)
 def erase(
     inpt: torch.Tensor,
     i: int,
...
torchvision/transforms/v2/functional/_color.py
View file @
db310636
...
...
@@ -10,12 +10,10 @@ from torchvision.transforms._functional_tensor import _max_value
from
torchvision.utils
import
_log_api_usage_once
from
._misc
import
_num_value_bits
,
to_dtype_image_tensor
from
._type_conversion
import
pil_to_tensor
,
to_image_pil
from
._utils
import
_get_kernel
,
_register_explicit_noop
,
_register_kernel_internal
from
._utils
import
_get_kernel
,
_register_kernel_internal
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
,
datapoints
.
Video
)
def
rgb_to_grayscale
(
inpt
:
torch
.
Tensor
,
num_output_channels
:
int
=
1
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
return
rgb_to_grayscale_image_tensor
(
inpt
,
num_output_channels
=
num_output_channels
)
...
...
@@ -70,8 +68,8 @@ def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Te
return
output
if
fp
else
output
.
to
(
image1
.
dtype
)
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
adjust_brightness
(
inpt
:
torch
.
Tensor
,
brightness_factor
:
float
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
return
adjust_brightness_image_tensor
(
inpt
,
brightness_factor
=
brightness_factor
)
...
...
@@ -107,7 +105,6 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to
return
adjust_brightness_image_tensor
(
video
,
brightness_factor
=
brightness_factor
)
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
adjust_saturation
(
inpt
:
torch
.
Tensor
,
saturation_factor
:
float
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
return
adjust_saturation_image_tensor
(
inpt
,
saturation_factor
=
saturation_factor
)
...
...
@@ -146,7 +143,6 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to
return
adjust_saturation_image_tensor
(
video
,
saturation_factor
=
saturation_factor
)
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
adjust_contrast
(
inpt
:
torch
.
Tensor
,
contrast_factor
:
float
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
return
adjust_contrast_image_tensor
(
inpt
,
contrast_factor
=
contrast_factor
)
...
...
@@ -185,7 +181,6 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.
return
adjust_contrast_image_tensor
(
video
,
contrast_factor
=
contrast_factor
)
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
adjust_sharpness
(
inpt
:
torch
.
Tensor
,
sharpness_factor
:
float
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
return
adjust_sharpness_image_tensor
(
inpt
,
sharpness_factor
=
sharpness_factor
)
...
...
@@ -258,7 +253,6 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc
return
adjust_sharpness_image_tensor
(
video
,
sharpness_factor
=
sharpness_factor
)
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
adjust_hue
(
inpt
:
torch
.
Tensor
,
hue_factor
:
float
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
return
adjust_hue_image_tensor
(
inpt
,
hue_factor
=
hue_factor
)
...
...
@@ -370,7 +364,6 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
return
adjust_hue_image_tensor
(
video
,
hue_factor
=
hue_factor
)
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
adjust_gamma
(
inpt
:
torch
.
Tensor
,
gamma
:
float
,
gain
:
float
=
1
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
return
adjust_gamma_image_tensor
(
inpt
,
gamma
=
gamma
,
gain
=
gain
)
...
...
@@ -410,7 +403,6 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to
return
adjust_gamma_image_tensor
(
video
,
gamma
=
gamma
,
gain
=
gain
)
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
posterize
(
inpt
:
torch
.
Tensor
,
bits
:
int
)
->
torch
.
Tensor
:
if
torch
.
jit
.
is_scripting
():
return
posterize_image_tensor
(
inpt
,
bits
=
bits
)
...
...
@@ -444,7 +436,6 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
    return posterize_image_tensor(video, bits=bits)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def solarize(inpt: torch.Tensor, threshold: float) -> torch.Tensor:
    if torch.jit.is_scripting():
        return solarize_image_tensor(inpt, threshold=threshold)
...
...
@@ -472,7 +463,6 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
    return solarize_image_tensor(video, threshold=threshold)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def autocontrast(inpt: torch.Tensor) -> torch.Tensor:
    if torch.jit.is_scripting():
        return autocontrast_image_tensor(inpt)
...
...
@@ -522,7 +512,6 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
    return autocontrast_image_tensor(video)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def equalize(inpt: torch.Tensor) -> torch.Tensor:
    if torch.jit.is_scripting():
        return equalize_image_tensor(inpt)
...
...
@@ -612,7 +601,6 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
    return equalize_image_tensor(video)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def invert(inpt: torch.Tensor) -> torch.Tensor:
    if torch.jit.is_scripting():
        return invert_image_tensor(inpt)
...
...
@@ -643,7 +631,6 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:
    return invert_image_tensor(video)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor:
    """Permute the channels of the input according to the given permutation.
...
...
torchvision/transforms/v2/functional/_geometry.py
View file @ db310636
...
...
@@ -25,13 +25,7 @@ from torchvision.utils import _log_api_usage_once
from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_size_image_pil
-from ._utils import (
-    _FillTypeJIT,
-    _get_kernel,
-    _register_explicit_noop,
-    _register_five_ten_crop_kernel,
-    _register_kernel_internal,
-)
+from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal


def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
...
...
@@ -2203,7 +2197,6 @@ def resized_crop_video(
    )


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True)
def five_crop(
    inpt: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
...
...
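five_crop is special because it returns a 5-tuple of crops rather than a single tensor, so its kernels cannot go through the standard output-wrapping machinery; that is what the renamed registration helper in the hunks below handles. A usage sketch (sizes are illustrative):

    import torch
    from torchvision.transforms.v2 import functional as F

    frame = torch.rand(3, 64, 64)
    tl, tr, bl, br, center = F.five_crop(frame, size=[32, 32])
    print(center.shape)  # torch.Size([3, 32, 32])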
@@ -2230,8 +2223,8 @@ def _parse_five_crop_size(size: List[int]) -> List[int]:
    return size


-@_register_five_ten_crop_kernel(five_crop, torch.Tensor)
-@_register_five_ten_crop_kernel(five_crop, datapoints.Image)
+@_register_five_ten_crop_kernel_internal(five_crop, torch.Tensor)
+@_register_five_ten_crop_kernel_internal(five_crop, datapoints.Image)
def five_crop_image_tensor(
    image: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
...
...
@@ -2250,7 +2243,7 @@ def five_crop_image_tensor(
    return tl, tr, bl, br, center


-@_register_five_ten_crop_kernel(five_crop, PIL.Image.Image)
+@_register_five_ten_crop_kernel_internal(five_crop, PIL.Image.Image)
def five_crop_image_pil(
    image: PIL.Image.Image, size: List[int]
) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
...
...
@@ -2269,14 +2262,13 @@ def five_crop_image_pil(
    return tl, tr, bl, br, center


-@_register_five_ten_crop_kernel(five_crop, datapoints.Video)
+@_register_five_ten_crop_kernel_internal(five_crop, datapoints.Video)
def five_crop_video(
    video: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    return five_crop_image_tensor(video, size)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True)
def ten_crop(
    inpt: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
...
...
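ten_crop composes the five crops with flipped copies of the same five crops. A usage sketch (illustrative sizes):

    import torch
    from torchvision.transforms.v2 import functional as F

    frame = torch.rand(3, 64, 64)
    crops = F.ten_crop(frame, size=[32, 32], vertical_flip=False)
    print(len(crops))  # 10: the five crops plus their horizontally flipped versions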
@@ -2300,8 +2292,8 @@ def ten_crop(
    return kernel(inpt, size=size, vertical_flip=vertical_flip)


-@_register_five_ten_crop_kernel(ten_crop, torch.Tensor)
-@_register_five_ten_crop_kernel(ten_crop, datapoints.Image)
+@_register_five_ten_crop_kernel_internal(ten_crop, torch.Tensor)
+@_register_five_ten_crop_kernel_internal(ten_crop, datapoints.Image)
def ten_crop_image_tensor(
    image: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
...
...
@@ -2328,7 +2320,7 @@ def ten_crop_image_tensor(
    return non_flipped + flipped


-@_register_five_ten_crop_kernel(ten_crop, PIL.Image.Image)
+@_register_five_ten_crop_kernel_internal(ten_crop, PIL.Image.Image)
def ten_crop_image_pil(
    image: PIL.Image.Image, size: List[int], vertical_flip: bool = False
) -> Tuple[
...
...
@@ -2355,7 +2347,7 @@ def ten_crop_image_pil(
    return non_flipped + flipped


-@_register_five_ten_crop_kernel(ten_crop, datapoints.Video)
+@_register_five_ten_crop_kernel_internal(ten_crop, datapoints.Video)
def ten_crop_video(
    video: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
...
...
torchvision/transforms/v2/functional/_meta.py
View file @ db310636
...
...
@@ -8,10 +8,9 @@ from torchvision.transforms import _functional_pil as _FP
from torchvision.utils import _log_api_usage_once

-from ._utils import _get_kernel, _register_kernel_internal, _register_unsupported_type, is_simple_tensor
+from ._utils import _get_kernel, _register_kernel_internal, is_simple_tensor


-@_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
def get_dimensions(inpt: torch.Tensor) -> List[int]:
    if torch.jit.is_scripting():
        return get_dimensions_image_tensor(inpt)
...
...
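The metadata dispatchers in this file return plain Python values rather than tensors, so the removed `_register_unsupported_type` markers were the only way to reject inputs for which the query makes no sense (e.g. channel counts of bounding boxes). For example:

    import torch
    from torchvision.transforms.v2 import functional as F

    img = torch.rand(3, 32, 64)
    print(F.get_dimensions(img))    # [3, 32, 64]
    print(F.get_num_channels(img))  # 3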
@@ -44,7 +43,6 @@ def get_dimensions_video(video: torch.Tensor) -> List[int]:
    return get_dimensions_image_tensor(video)


-@_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
def get_num_channels(inpt: torch.Tensor) -> int:
    if torch.jit.is_scripting():
        return get_num_channels_image_tensor(inpt)
...
...
@@ -123,7 +121,6 @@ def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]:
    return list(bounding_box.canvas_size)


-@_register_unsupported_type(PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask)
def get_num_frames(inpt: torch.Tensor) -> int:
    if torch.jit.is_scripting():
        return get_num_frames_video(inpt)
...
...
torchvision/transforms/v2/functional/_misc.py
View file @ db310636
...
...
@@ -11,11 +11,9 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from torchvision.utils import _log_api_usage_once

-from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, _register_unsupported_type
+from ._utils import _get_kernel, _register_kernel_internal


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-@_register_unsupported_type(PIL.Image.Image)
def normalize(
    inpt: torch.Tensor,
    mean: List[float],
...
...
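normalize only makes sense for floating-point tensor inputs, which is why it previously carried both the no-op registration for boxes/masks and an explicit rejection of PIL images. A usage sketch with the common ImageNet statistics:

    import torch
    from torchvision.transforms.v2 import functional as F

    img = torch.rand(3, 224, 224)  # float values in [0, 1]
    out = F.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])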
@@ -73,7 +71,6 @@ def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
    return normalize_image_tensor(video, mean, std, inplace=inplace)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def gaussian_blur(inpt: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> torch.Tensor:
    if torch.jit.is_scripting():
        return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
...
...
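gaussian_blur keeps the input shape and merely smooths pixel values, so a box/mask passthrough is harmless; removing the decorator shifts that passthrough into the transforms. Usage sketch:

    import torch
    from torchvision.transforms.v2 import functional as F

    img = torch.rand(3, 32, 32)
    blurred = F.gaussian_blur(img, kernel_size=[5, 5], sigma=[1.5, 1.5])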
@@ -182,7 +179,6 @@ def gaussian_blur_video(
    return gaussian_blur_image_tensor(video, kernel_size, sigma)


-@_register_unsupported_type(PIL.Image.Image)
def to_dtype(inpt: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
    if torch.jit.is_scripting():
        return to_dtype_image_tensor(inpt, dtype=dtype, scale=scale)
...
...
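to_dtype with scale=True also rescales values to the target dtype's conventional range, not just casting. Usage sketch:

    import torch
    from torchvision.transforms.v2 import functional as F

    u8 = torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8)
    f32 = F.to_dtype(u8, dtype=torch.float32, scale=True)  # uint8 [0, 255] -> float32 [0, 1]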
torchvision/transforms/v2/functional/_temporal.py
View file @ db310636
import PIL.Image
import torch

from torchvision import datapoints
from torchvision.utils import _log_api_usage_once

-from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
+from ._utils import _get_kernel, _register_kernel_internal


-@_register_explicit_noop(
-    PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True
-)
def uniform_temporal_subsample(inpt: torch.Tensor, num_samples: int) -> torch.Tensor:
    if torch.jit.is_scripting():
        return uniform_temporal_subsample_video(inpt, num_samples=num_samples)
...
...
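uniform_temporal_subsample only applies to videos, which is why every other builtin type was registered as a (warning) no-op before this change. A usage sketch, assuming the (T, C, H, W) video layout:

    import torch
    from torchvision.transforms.v2 import functional as F

    video = torch.rand(16, 3, 8, 8)  # (T, C, H, W)
    sub = F.uniform_temporal_subsample(video, num_samples=4)
    print(sub.shape)  # torch.Size([4, 3, 8, 8])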
torchvision/transforms/v2/functional/_utils.py
View file @ db310636
import functools
-import warnings
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union

import torch
...
...
@@ -53,6 +52,11 @@ def _name_to_dispatcher(name):
    ) from None


+_BUILTIN_DATAPOINT_TYPES = {
+    obj for obj in datapoints.__dict__.values() if isinstance(obj, type) and issubclass(obj, datapoints.Datapoint)
+}


def register_kernel(dispatcher, datapoint_cls):
    """Decorate a kernel to register it for a dispatcher and a (custom) datapoint type.
...
...
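register_kernel remains the public extension point for custom datapoints. A hedged sketch (MyDatapoint is a hypothetical user-defined class):

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    class MyDatapoint(datapoints.Datapoint):
        pass

    @F.register_kernel(F.adjust_brightness, MyDatapoint)
    def adjust_brightness_my_datapoint(inpt, brightness_factor):
        # a real kernel would transform the underlying data here
        return inpt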
@@ -70,20 +74,19 @@ def register_kernel(dispatcher, datapoint_cls):
            f"but got {dispatcher}."
        )

-    if not (
-        isinstance(datapoint_cls, type)
-        and issubclass(datapoint_cls, datapoints.Datapoint)
-        and datapoint_cls is not datapoints.Datapoint
-    ):
+    if not (isinstance(datapoint_cls, type) and issubclass(datapoint_cls, datapoints.Datapoint)):
        raise ValueError(
            f"Kernels can only be registered for subclasses of torchvision.datapoints.Datapoint, "
            f"but got {datapoint_cls}."
        )

+    if datapoint_cls in _BUILTIN_DATAPOINT_TYPES:
+        raise ValueError(f"Kernels cannot be registered for the builtin datapoint classes, but got {datapoint_cls}")
+
    return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False)


-def _get_kernel(dispatcher, input_type):
+def _get_kernel(dispatcher, input_type, *, allow_passthrough=False):
    registry = _KERNEL_REGISTRY.get(dispatcher)
    if not registry:
        raise ValueError(f"No kernel registered for dispatcher {dispatcher.__name__}.")
...
...
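With the new _BUILTIN_DATAPOINT_TYPES guard, registering against a builtin datapoint class now fails loudly instead of relying on the explicit no-ops that this commit deletes. Sketch:

    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    try:
        @F.register_kernel(F.adjust_brightness, datapoints.Image)
        def not_allowed(inpt, brightness_factor):
            return inpt
    except ValueError as exc:
        print(exc)  # Kernels cannot be registered for the builtin datapoint classes, ...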
@@ -104,78 +107,18 @@ def _get_kernel(dispatcher, input_type):
        elif cls in registry:
            return registry[cls]

-    # Note that in the future we are not going to return a noop here, but rather raise the error below
-    return _noop
+    if allow_passthrough:
+        return lambda inpt, *args, **kwargs: inpt

    raise TypeError(
-        f"Dispatcher {dispatcher} supports inputs of type torch.Tensor, PIL.Image.Image, "
-        f"and subclasses of torchvision.datapoints.Datapoint, "
+        f"Dispatcher F.{dispatcher.__name__} supports inputs of type {registry.keys()}, "
        f"but got {input_type} instead."
    )
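The hunk above is the heart of the change: instead of registering `_noop` kernels on the dispatchers, `_get_kernel` now passes unknown types through only when the caller opts in via `allow_passthrough=True` (the v2 Transform classes opt in; direct functional calls do not). A toy re-implementation of that contract, not the real registry:

    def get_kernel_sketch(registry, input_type, *, allow_passthrough=False):
        # look up a kernel for the exact input type
        if input_type in registry:
            return registry[input_type]
        # transforms opt in to this branch, so unknown types pass through unchanged
        if allow_passthrough:
            return lambda inpt, *args, **kwargs: inpt
        # direct functional calls on unsupported types now error
        raise TypeError(f"unsupported input type {input_type}")

    identity = get_kernel_sketch({}, str, allow_passthrough=True)
    assert identity("anything") == "anything"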
-# Everything below this block is stuff that we need right now, since it looks like we need to release in an intermediate
-# stage. See https://github.com/pytorch/vision/pull/7747#issuecomment-1661698450 for details.
-# In the future, the default behavior will be to error on unsupported types in dispatchers. The noop behavior that we
-# need for transforms will be handled by _get_kernel rather than actually registering no-ops on the dispatcher.
-# Finally, the use case of preventing users from registering kernels for our builtin types will be handled inside
-# register_kernel.
-def _register_explicit_noop(*datapoints_classes, warn_passthrough=False):
-    """
-    Although this looks redundant with the no-op behavior of _get_kernel, this explicit registration prevents users
-    from registering kernels for builtin datapoints on builtin dispatchers that rely on the no-op behavior.
-
-    For example, without explicit no-op registration the following would be valid user code:
-
-    .. code::
-
-        from torchvision.transforms.v2 import functional as F
-
-        @F.register_kernel(F.adjust_brightness, datapoints.BoundingBox)
-        def lol(...):
-            ...
-    """
-
-    def decorator(dispatcher):
-        for cls in datapoints_classes:
-            msg = (
-                f"F.{dispatcher.__name__} is currently passing through inputs of type datapoints.{cls.__name__}. "
-                f"This will likely change in the future."
-            )
-            _register_kernel_internal(dispatcher, cls, datapoint_wrapper=False)(
-                functools.partial(_noop, __msg__=msg if warn_passthrough else None)
-            )
-        return dispatcher
-
-    return decorator
-def _noop(inpt, *args, __msg__=None, **kwargs):
-    if __msg__:
-        warnings.warn(__msg__, UserWarning, stacklevel=2)
-    return inpt
-# TODO: we only need this, since our default behavior in case no kernel is found is passthrough. When we change that
-# to error later, this decorator can be removed, since the error will be raised by _get_kernel
-def _register_unsupported_type(*input_types):
-    def kernel(inpt, *args, __dispatcher_name__, **kwargs):
-        raise TypeError(f"F.{__dispatcher_name__} does not support inputs of type {type(inpt)}.")
-
-    def decorator(dispatcher):
-        for input_type in input_types:
-            _register_kernel_internal(dispatcher, input_type, datapoint_wrapper=False)(
-                functools.partial(kernel, __dispatcher_name__=dispatcher.__name__)
-            )
-        return dispatcher
-
-    return decorator
# This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop
# We could get rid of this by letting _register_kernel_internal take arbitrary dispatchers rather than wrap_kernel: bool
-def _register_five_ten_crop_kernel(dispatcher, input_type):
+def _register_five_ten_crop_kernel_internal(dispatcher, input_type):
    registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
    if input_type in registry:
        raise TypeError(f"Dispatcher '{dispatcher}' already has a kernel registered for type '{input_type}'.")
...
...