OpenDAS / vision / Commits

Commit db310636 (unverified), authored Aug 10, 2023 by Philip Meier, committed via GitHub on Aug 10, 2023

move passthrough for unknown types from dispatchers to transforms (#7804)

parent 87681314
Showing 18 changed files with 150 additions and 731 deletions (+150 −731)
test/test_prototype_transforms.py                   +0   −63
test/test_transforms_v2.py                          +2   −360
test/test_transforms_v2_consistency.py              +0   −62
test/test_transforms_v2_refactored.py               +16  −40
torchvision/prototype/transforms/_geometry.py       +3   −2
torchvision/transforms/v2/_augment.py               +10  −2
torchvision/transforms/v2/_color.py                 +21  −50
torchvision/transforms/v2/_geometry.py              +57  −24
torchvision/transforms/v2/_misc.py                  +5   −8
torchvision/transforms/v2/_temporal.py              +1   −1
torchvision/transforms/v2/_transform.py             +6   −0
torchvision/transforms/v2/functional/_augment.py    +1   −2
torchvision/transforms/v2/functional/_color.py      +2   −15
torchvision/transforms/v2/functional/_geometry.py   +9   −17
torchvision/transforms/v2/functional/_meta.py       +1   −4
torchvision/transforms/v2/functional/_misc.py       +1   −5
torchvision/transforms/v2/functional/_temporal.py   +1   −5
torchvision/transforms/v2/functional/_utils.py      +14  −71
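The diffs below implement the change named in the commit title: the v2 functional dispatchers used to silently pass through input types they had no kernel for, and now raise a TypeError instead, while the passthrough moves into the Transform classes via a new self._call_kernel indirection. A minimal sketch of the resulting behavior, assuming the beta datapoints API of this era (MyDatapoint is a hypothetical custom subclass, not part of the commit):

    import torch
    from torchvision import datapoints
    from torchvision.transforms import v2 as transforms
    from torchvision.transforms.v2 import functional as F

    class MyDatapoint(datapoints.Datapoint):
        # A custom datapoint subclass with no kernels registered.
        pass

    dp = torch.rand(3, 8, 8).as_subclass(MyDatapoint)

    # After this commit, the dispatcher rejects types without a registered kernel ...
    try:
        F.horizontal_flip(dp)
    except TypeError as e:
        print(e)  # "... supports inputs of type ..."

    # ... while the transform forwards the unknown type unchanged (passthrough).
    out = transforms.RandomHorizontalFlip(p=1.0)(dp)
    assert isinstance(out, MyDatapoint)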
test/test_prototype_transforms.py  (view file @ db310636)

-import itertools
 import re

 import PIL.Image
...
@@ -19,7 +17,6 @@ from prototype_common_utils import make_label
 from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
 from torchvision.prototype import datapoints, transforms
-from torchvision.transforms.v2._utils import _convert_fill_arg
 from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_image_pil
 from torchvision.transforms.v2.utils import check_type, is_simple_tensor
...
@@ -187,66 +184,6 @@ class TestFixedSizeCrop:
         assert params["needs_pad"]
         assert any(pad > 0 for pad in params["padding"])

-    @pytest.mark.parametrize("needs", list(itertools.product((False, True), repeat=2)))
-    def test__transform(self, mocker, needs):
-        fill_sentinel = 12
-        padding_mode_sentinel = mocker.MagicMock()
-
-        transform = transforms.FixedSizeCrop((-1, -1), fill=fill_sentinel, padding_mode=padding_mode_sentinel)
-        transform._transformed_types = (mocker.MagicMock,)
-        mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
-
-        needs_crop, needs_pad = needs
-        top_sentinel = mocker.MagicMock()
-        left_sentinel = mocker.MagicMock()
-        height_sentinel = mocker.MagicMock()
-        width_sentinel = mocker.MagicMock()
-        is_valid = mocker.MagicMock() if needs_crop else None
-        padding_sentinel = mocker.MagicMock()
-        mocker.patch(
-            "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
-            return_value=dict(
-                needs_crop=needs_crop,
-                top=top_sentinel,
-                left=left_sentinel,
-                height=height_sentinel,
-                width=width_sentinel,
-                is_valid=is_valid,
-                padding=padding_sentinel,
-                needs_pad=needs_pad,
-            ),
-        )
-
-        inpt_sentinel = mocker.MagicMock()
-
-        mock_crop = mocker.patch("torchvision.prototype.transforms._geometry.F.crop")
-        mock_pad = mocker.patch("torchvision.prototype.transforms._geometry.F.pad")
-        transform(inpt_sentinel)
-
-        if needs_crop:
-            mock_crop.assert_called_once_with(
-                inpt_sentinel,
-                top=top_sentinel,
-                left=left_sentinel,
-                height=height_sentinel,
-                width=width_sentinel,
-            )
-        else:
-            mock_crop.assert_not_called()
-
-        if needs_pad:
-            # If we cropped before, the input to F.pad is no longer inpt_sentinel. Thus, we can't use
-            # `MagicMock.assert_called_once_with` and have to perform the checks manually
-            mock_pad.assert_called_once()
-            args, kwargs = mock_pad.call_args
-            if not needs_crop:
-                assert args[0] is inpt_sentinel
-            assert args[1] is padding_sentinel
-            fill_sentinel = _convert_fill_arg(fill_sentinel)
-            assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel)
-        else:
-            mock_pad.assert_not_called()
-
     def test__transform_culling(self, mocker):
         batch_size = 10
         canvas_size = (10, 10)
...
test/test_transforms_v2.py  (view file @ db310636)

...
@@ -27,7 +27,7 @@ from common_utils import (
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision import datapoints
 from torchvision.ops.boxes import box_iou
-from torchvision.transforms.functional import InterpolationMode, to_pil_image
+from torchvision.transforms.functional import to_pil_image
 from torchvision.transforms.v2 import functional as F
 from torchvision.transforms.v2.utils import check_type, is_simple_tensor, query_chw
...
@@ -419,46 +419,6 @@ class TestPad:
         with pytest.raises(ValueError, match="Padding mode should be either"):
             transforms.Pad(12, padding_mode="abc")

-    @pytest.mark.parametrize("padding", [1, (1, 2), [1, 2, 3, 4]])
-    @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)])
-    @pytest.mark.parametrize("padding_mode", ["constant", "edge"])
-    def test__transform(self, padding, fill, padding_mode, mocker):
-        transform = transforms.Pad(padding, fill=fill, padding_mode=padding_mode)
-
-        fn = mocker.patch("torchvision.transforms.v2.functional.pad")
-        inpt = mocker.MagicMock(spec=datapoints.Image)
-        _ = transform(inpt)
-
-        fill = transforms._utils._convert_fill_arg(fill)
-        if isinstance(padding, tuple):
-            padding = list(padding)
-        fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
-
-    @pytest.mark.parametrize("fill", [12, {datapoints.Image: 12, datapoints.Mask: 34}])
-    def test__transform_image_mask(self, fill, mocker):
-        transform = transforms.Pad(1, fill=fill, padding_mode="constant")
-
-        fn = mocker.patch("torchvision.transforms.v2.functional.pad")
-        image = datapoints.Image(torch.rand(3, 32, 32))
-        mask = datapoints.Mask(torch.randint(0, 5, size=(32, 32)))
-        inpt = [image, mask]
-        _ = transform(inpt)
-
-        if isinstance(fill, int):
-            fill = transforms._utils._convert_fill_arg(fill)
-            calls = [
-                mocker.call(image, padding=1, fill=fill, padding_mode="constant"),
-                mocker.call(mask, padding=1, fill=fill, padding_mode="constant"),
-            ]
-        else:
-            fill_img = transforms._utils._convert_fill_arg(fill[type(image)])
-            fill_mask = transforms._utils._convert_fill_arg(fill[type(mask)])
-            calls = [
-                mocker.call(image, padding=1, fill=fill_img, padding_mode="constant"),
-                mocker.call(mask, padding=1, fill=fill_mask, padding_mode="constant"),
-            ]
-        fn.assert_has_calls(calls)
-
 class TestRandomZoomOut:
     def test_assertions(self):
...
@@ -487,56 +447,6 @@ class TestRandomZoomOut:
         assert 0 <= params["padding"][2] <= (side_range[1] - 1) * w
         assert 0 <= params["padding"][3] <= (side_range[1] - 1) * h

-    @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)])
-    @pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]])
-    def test__transform(self, fill, side_range, mocker):
-        inpt = make_image((24, 32))
-
-        transform = transforms.RandomZoomOut(fill=fill, side_range=side_range, p=1)
-
-        fn = mocker.patch("torchvision.transforms.v2.functional.pad")
-        # vfdev-5, Feature Request: let's store params as Transform attribute
-        # This could be also helpful for users
-        # Otherwise, we can mock transform._get_params
-        torch.manual_seed(12)
-        _ = transform(inpt)
-        torch.manual_seed(12)
-        torch.rand(1)  # random apply changes random state
-        params = transform._get_params([inpt])
-
-        fill = transforms._utils._convert_fill_arg(fill)
-        fn.assert_called_once_with(inpt, **params, fill=fill)
-
-    @pytest.mark.parametrize("fill", [12, {datapoints.Image: 12, datapoints.Mask: 34}])
-    def test__transform_image_mask(self, fill, mocker):
-        transform = transforms.RandomZoomOut(fill=fill, p=1.0)
-
-        fn = mocker.patch("torchvision.transforms.v2.functional.pad")
-        image = datapoints.Image(torch.rand(3, 32, 32))
-        mask = datapoints.Mask(torch.randint(0, 5, size=(32, 32)))
-        inpt = [image, mask]
-
-        torch.manual_seed(12)
-        _ = transform(inpt)
-        torch.manual_seed(12)
-        torch.rand(1)  # random apply changes random state
-        params = transform._get_params(inpt)
-
-        if isinstance(fill, int):
-            fill = transforms._utils._convert_fill_arg(fill)
-            calls = [
-                mocker.call(image, **params, fill=fill),
-                mocker.call(mask, **params, fill=fill),
-            ]
-        else:
-            fill_img = transforms._utils._convert_fill_arg(fill[type(image)])
-            fill_mask = transforms._utils._convert_fill_arg(fill[type(mask)])
-            calls = [
-                mocker.call(image, **params, fill=fill_img),
-                mocker.call(mask, **params, fill=fill_mask),
-            ]
-        fn.assert_has_calls(calls)
-
 class TestRandomCrop:
     def test_assertions(self):
...
@@ -599,51 +509,6 @@ class TestRandomCrop:
         assert params["needs_pad"] is any(padding)
         assert params["padding"] == padding

-    @pytest.mark.parametrize("padding", [None, 1, [2, 3], [1, 2, 3, 4]])
-    @pytest.mark.parametrize("pad_if_needed", [False, True])
-    @pytest.mark.parametrize("fill", [False, True])
-    @pytest.mark.parametrize("padding_mode", ["constant", "edge"])
-    def test__transform(self, padding, pad_if_needed, fill, padding_mode, mocker):
-        output_size = [10, 12]
-        transform = transforms.RandomCrop(
-            output_size, padding=padding, pad_if_needed=pad_if_needed, fill=fill, padding_mode=padding_mode
-        )
-
-        h, w = size = (32, 32)
-        inpt = make_image(size)
-
-        if isinstance(padding, int):
-            new_size = (h + padding, w + padding)
-        elif isinstance(padding, list):
-            new_size = (h + sum(padding[0::2]), w + sum(padding[1::2]))
-        else:
-            new_size = size
-        expected = make_image(new_size)
-        _ = mocker.patch("torchvision.transforms.v2.functional.pad", return_value=expected)
-        fn_crop = mocker.patch("torchvision.transforms.v2.functional.crop")
-
-        # vfdev-5, Feature Request: let's store params as Transform attribute
-        # This could be also helpful for users
-        # Otherwise, we can mock transform._get_params
-        torch.manual_seed(12)
-        _ = transform(inpt)
-        torch.manual_seed(12)
-        params = transform._get_params([inpt])
-
-        if padding is None and not pad_if_needed:
-            fn_crop.assert_called_once_with(
-                inpt, top=params["top"], left=params["left"], height=output_size[0], width=output_size[1]
-            )
-        elif not pad_if_needed:
-            fn_crop.assert_called_once_with(
-                expected, top=params["top"], left=params["left"], height=output_size[0], width=output_size[1]
-            )
-        elif padding is None:
-            # vfdev-5: I do not know how to mock and test this case
-            pass
-        else:
-            # vfdev-5: I do not know how to mock and test this case
-            pass
-
 class TestGaussianBlur:
     def test_assertions(self):
...
@@ -675,62 +540,6 @@ class TestGaussianBlur:
         assert sigma[0] <= params["sigma"][0] <= sigma[1]
         assert sigma[0] <= params["sigma"][1] <= sigma[1]

-    @pytest.mark.parametrize("kernel_size", [3, [3, 5], (5, 3)])
-    @pytest.mark.parametrize("sigma", [2.0, [2.0, 3.0]])
-    def test__transform(self, kernel_size, sigma, mocker):
-        transform = transforms.GaussianBlur(kernel_size=kernel_size, sigma=sigma)
-
-        if isinstance(kernel_size, (tuple, list)):
-            assert transform.kernel_size == kernel_size
-        else:
-            kernel_size = (kernel_size, kernel_size)
-            assert transform.kernel_size == kernel_size
-
-        if isinstance(sigma, (tuple, list)):
-            assert transform.sigma == sigma
-        else:
-            assert transform.sigma == [sigma, sigma]
-
-        fn = mocker.patch("torchvision.transforms.v2.functional.gaussian_blur")
-        inpt = mocker.MagicMock(spec=datapoints.Image)
-        inpt.num_channels = 3
-        inpt.canvas_size = (24, 32)
-
-        # vfdev-5, Feature Request: let's store params as Transform attribute
-        # This could be also helpful for users
-        # Otherwise, we can mock transform._get_params
-        torch.manual_seed(12)
-        _ = transform(inpt)
-        torch.manual_seed(12)
-        params = transform._get_params([inpt])
-
-        fn.assert_called_once_with(inpt, kernel_size, **params)
-
-class TestRandomColorOp:
-    @pytest.mark.parametrize("p", [0.0, 1.0])
-    @pytest.mark.parametrize(
-        "transform_cls, func_op_name, kwargs",
-        [
-            (transforms.RandomEqualize, "equalize", {}),
-            (transforms.RandomInvert, "invert", {}),
-            (transforms.RandomAutocontrast, "autocontrast", {}),
-            (transforms.RandomPosterize, "posterize", {"bits": 4}),
-            (transforms.RandomSolarize, "solarize", {"threshold": 0.5}),
-            (transforms.RandomAdjustSharpness, "adjust_sharpness", {"sharpness_factor": 0.5}),
-        ],
-    )
-    def test__transform(self, p, transform_cls, func_op_name, kwargs, mocker):
-        transform = transform_cls(p=p, **kwargs)
-
-        fn = mocker.patch(f"torchvision.transforms.v2.functional.{func_op_name}")
-        inpt = mocker.MagicMock(spec=datapoints.Image)
-        _ = transform(inpt)
-
-        if p > 0.0:
-            fn.assert_called_once_with(inpt, **kwargs)
-        else:
-            assert fn.call_count == 0
-
 class TestRandomPerspective:
     def test_assertions(self):
...
@@ -751,28 +560,6 @@ class TestRandomPerspective:
         assert "coefficients" in params
         assert len(params["coefficients"]) == 8

-    @pytest.mark.parametrize("distortion_scale", [0.1, 0.7])
-    def test__transform(self, distortion_scale, mocker):
-        interpolation = InterpolationMode.BILINEAR
-        fill = 12
-        transform = transforms.RandomPerspective(distortion_scale, fill=fill, interpolation=interpolation)
-
-        fn = mocker.patch("torchvision.transforms.v2.functional.perspective")
-        inpt = make_image((24, 32))
-
-        # vfdev-5, Feature Request: let's store params as Transform attribute
-        # This could be also helpful for users
-        # Otherwise, we can mock transform._get_params
-        torch.manual_seed(12)
-        _ = transform(inpt)
-        torch.manual_seed(12)
-        torch.rand(1)  # random apply changes random state
-        params = transform._get_params([inpt])
-
-        fill = transforms._utils._convert_fill_arg(fill)
-        fn.assert_called_once_with(inpt, None, None, **params, fill=fill, interpolation=interpolation)
-
 class TestElasticTransform:
     def test_assertions(self):
...
@@ -813,35 +600,6 @@ class TestElasticTransform:
         assert (-alpha / w <= displacement[0, ..., 0]).all() and (displacement[0, ..., 0] <= alpha / w).all()
         assert (-alpha / h <= displacement[0, ..., 1]).all() and (displacement[0, ..., 1] <= alpha / h).all()

-    @pytest.mark.parametrize("alpha", [5.0, [5.0, 10.0]])
-    @pytest.mark.parametrize("sigma", [2.0, [2.0, 5.0]])
-    def test__transform(self, alpha, sigma, mocker):
-        interpolation = InterpolationMode.BILINEAR
-        fill = 12
-        transform = transforms.ElasticTransform(alpha, sigma=sigma, fill=fill, interpolation=interpolation)
-
-        if isinstance(alpha, float):
-            assert transform.alpha == [alpha, alpha]
-        else:
-            assert transform.alpha == alpha
-
-        if isinstance(sigma, float):
-            assert transform.sigma == [sigma, sigma]
-        else:
-            assert transform.sigma == sigma
-
-        fn = mocker.patch("torchvision.transforms.v2.functional.elastic")
-        inpt = mocker.MagicMock(spec=datapoints.Image)
-        inpt.num_channels = 3
-        inpt.canvas_size = (24, 32)
-
-        # Let's mock transform._get_params to control the output:
-        transform._get_params = mocker.MagicMock()
-        _ = transform(inpt)
-        params = transform._get_params([inpt])
-
-        fill = transforms._utils._convert_fill_arg(fill)
-        fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
-
 class TestRandomErasing:
     def test_assertions(self):
...
@@ -889,40 +647,6 @@ class TestRandomErasing:
         assert 0 <= i <= height - h
         assert 0 <= j <= width - w

-    @pytest.mark.parametrize("p", [0, 1])
-    def test__transform(self, mocker, p):
-        transform = transforms.RandomErasing(p=p)
-        transform._transformed_types = (mocker.MagicMock,)
-
-        i_sentinel = mocker.MagicMock()
-        j_sentinel = mocker.MagicMock()
-        h_sentinel = mocker.MagicMock()
-        w_sentinel = mocker.MagicMock()
-        v_sentinel = mocker.MagicMock()
-        mocker.patch(
-            "torchvision.transforms.v2._augment.RandomErasing._get_params",
-            return_value=dict(i=i_sentinel, j=j_sentinel, h=h_sentinel, w=w_sentinel, v=v_sentinel),
-        )
-
-        inpt_sentinel = mocker.MagicMock()
-
-        mock = mocker.patch("torchvision.transforms.v2._augment.F.erase")
-        output = transform(inpt_sentinel)
-
-        if p:
-            mock.assert_called_once_with(
-                inpt_sentinel,
-                i=i_sentinel,
-                j=j_sentinel,
-                h=h_sentinel,
-                w=w_sentinel,
-                v=v_sentinel,
-                inplace=transform.inplace,
-            )
-        else:
-            mock.assert_not_called()
-        assert output is inpt_sentinel
-
 class TestTransform:
     @pytest.mark.parametrize(
...
@@ -1111,23 +835,12 @@ class TestRandomIoUCrop:
         sample = [image, bboxes, masks]

-        fn = mocker.patch("torchvision.transforms.v2.functional.crop", side_effect=lambda x, **params: x)
         is_within_crop_area = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool)

         params = dict(top=1, left=2, height=12, width=12, is_within_crop_area=is_within_crop_area)
         transform._get_params = mocker.MagicMock(return_value=params)
         output = transform(sample)

-        assert fn.call_count == 3
-
-        expected_calls = [
-            mocker.call(image, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
-            mocker.call(bboxes, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
-            mocker.call(masks, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
-        ]
-        fn.assert_has_calls(expected_calls)
-
         # check number of bboxes vs number of labels:
         output_bboxes = output[1]
         assert isinstance(output_bboxes, datapoints.BoundingBoxes)
...
@@ -1164,29 +877,6 @@ class TestScaleJitter:
         assert int(canvas_size[0] * r_min) <= height <= int(canvas_size[0] * r_max)
         assert int(canvas_size[1] * r_min) <= width <= int(canvas_size[1] * r_max)

-    def test__transform(self, mocker):
-        interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
-        antialias_sentinel = mocker.MagicMock()
-
-        transform = transforms.ScaleJitter(
-            target_size=(16, 12), interpolation=interpolation_sentinel, antialias=antialias_sentinel
-        )
-        transform._transformed_types = (mocker.MagicMock,)
-
-        size_sentinel = mocker.MagicMock()
-        mocker.patch(
-            "torchvision.transforms.v2._geometry.ScaleJitter._get_params", return_value=dict(size=size_sentinel)
-        )
-
-        inpt_sentinel = mocker.MagicMock()
-
-        mock = mocker.patch("torchvision.transforms.v2._geometry.F.resize")
-        transform(inpt_sentinel)
-
-        mock.assert_called_once_with(
-            inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel
-        )
-
 class TestRandomShortestSize:
     @pytest.mark.parametrize("min_size,max_size", [([5, 9], 20), ([5, 9], None)])
...
@@ -1211,30 +901,6 @@ class TestRandomShortestSize:
         else:
             assert shorter in min_size

-    def test__transform(self, mocker):
-        interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
-        antialias_sentinel = mocker.MagicMock()
-
-        transform = transforms.RandomShortestSize(
-            min_size=[3, 5, 7], max_size=12, interpolation=interpolation_sentinel, antialias=antialias_sentinel
-        )
-        transform._transformed_types = (mocker.MagicMock,)
-
-        size_sentinel = mocker.MagicMock()
-        mocker.patch(
-            "torchvision.transforms.v2._geometry.RandomShortestSize._get_params",
-            return_value=dict(size=size_sentinel),
-        )
-
-        inpt_sentinel = mocker.MagicMock()
-
-        mock = mocker.patch("torchvision.transforms.v2._geometry.F.resize")
-        transform(inpt_sentinel)
-
-        mock.assert_called_once_with(
-            inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel
-        )
-
 class TestLinearTransformation:
     def test_assertions(self):
...
@@ -1260,7 +926,7 @@ class TestLinearTransformation:
         transform = transforms.LinearTransformation(m, v)

         if isinstance(inpt, PIL.Image.Image):
-            with pytest.raises(TypeError, match="LinearTransformation does not work on PIL Images"):
+            with pytest.raises(TypeError, match="does not support PIL images"):
                 transform(inpt)
         else:
             output = transform(inpt)
...
@@ -1284,30 +950,6 @@ class TestRandomResize:
         assert min_size <= size < max_size

-    def test__transform(self, mocker):
-        interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
-        antialias_sentinel = mocker.MagicMock()
-
-        transform = transforms.RandomResize(
-            min_size=-1, max_size=-1, interpolation=interpolation_sentinel, antialias=antialias_sentinel
-        )
-        transform._transformed_types = (mocker.MagicMock,)
-
-        size_sentinel = mocker.MagicMock()
-        mocker.patch(
-            "torchvision.transforms.v2._geometry.RandomResize._get_params",
-            return_value=dict(size=size_sentinel),
-        )
-
-        inpt_sentinel = mocker.MagicMock()
-
-        mock_resize = mocker.patch("torchvision.transforms.v2._geometry.F.resize")
-        transform(inpt_sentinel)
-
-        mock_resize.assert_called_with(inpt_sentinel, size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel)
-
 class TestUniformTemporalSubsample:
     @pytest.mark.parametrize(
...
test/test_transforms_v2_consistency.py  (view file @ db310636)

...
@@ -1259,68 +1259,6 @@ class TestRefSegTransforms:
     def test_common(self, t_ref, t, data_kwargs):
         self.check(t, t_ref, data_kwargs)

-    def check_resize(self, mocker, t_ref, t):
-        mock = mocker.patch("torchvision.transforms.v2._geometry.F.resize")
-        mock_ref = mocker.patch("torchvision.transforms.functional.resize")
-
-        for dp, dp_ref in self.make_datapoints():
-            mock.reset_mock()
-            mock_ref.reset_mock()
-
-            self.set_seed()
-            t(dp)
-            assert mock.call_count == 2
-            assert all(
-                actual is expected
-                for actual, expected in zip([call_args[0][0] for call_args in mock.call_args_list], dp)
-            )
-
-            self.set_seed()
-            t_ref(*dp_ref)
-            assert mock_ref.call_count == 2
-            assert all(
-                actual is expected
-                for actual, expected in zip([call_args[0][0] for call_args in mock_ref.call_args_list], dp_ref)
-            )
-
-            for args_kwargs, args_kwargs_ref in zip(mock.call_args_list, mock_ref.call_args_list):
-                assert args_kwargs[0][1] == [args_kwargs_ref[0][1]]
-
-    def test_random_resize_train(self, mocker):
-        base_size = 520
-        min_size = base_size // 2
-        max_size = base_size * 2
-
-        randint = torch.randint
-
-        def patched_randint(a, b, *other_args, **kwargs):
-            if kwargs or len(other_args) > 1 or other_args[0] != ():
-                return randint(a, b, *other_args, **kwargs)
-            return random.randint(a, b)
-
-        # We are patching torch.randint -> random.randint here, because we can't patch the modules that are not imported
-        # normally
-        t = v2_transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True)
-        mocker.patch(
-            "torchvision.transforms.v2._geometry.torch.randint",
-            new=patched_randint,
-        )
-
-        t_ref = seg_transforms.RandomResize(min_size=min_size, max_size=max_size)
-
-        self.check_resize(mocker, t_ref, t)
-
-    def test_random_resize_eval(self, mocker):
-        torch.manual_seed(0)
-        base_size = 520
-
-        t = v2_transforms.Resize(size=base_size, antialias=True)
-        t_ref = seg_transforms.RandomResize(min_size=base_size, max_size=base_size)
-
-        self.check_resize(mocker, t_ref, t)
-
     @pytest.mark.parametrize(
         ("legacy_dispatcher", "name_only_params"),
...
test/test_transforms_v2_refactored.py  (view file @ db310636)

...
@@ -39,7 +39,7 @@ from torchvision import datapoints
 from torchvision.transforms._functional_tensor import _max_value as get_max_value
 from torchvision.transforms.functional import pil_modes_mapping
 from torchvision.transforms.v2 import functional as F
-from torchvision.transforms.v2.functional._utils import _get_kernel, _KERNEL_REGISTRY, _noop, _register_kernel_internal
+from torchvision.transforms.v2.functional._utils import _get_kernel, _register_kernel_internal

 @pytest.fixture(autouse=True)
...
@@ -376,35 +376,6 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
     return torch.stack([transform(b) for b in bounding_boxes.reshape(-1, 4).unbind()]).reshape(bounding_boxes.shape)

-@pytest.mark.parametrize(
-    ("dispatcher", "registered_input_types"),
-    [(dispatcher, set(registry.keys())) for dispatcher, registry in _KERNEL_REGISTRY.items()],
-)
-def test_exhaustive_kernel_registration(dispatcher, registered_input_types):
-    missing = {
-        torch.Tensor,
-        PIL.Image.Image,
-        datapoints.Image,
-        datapoints.BoundingBoxes,
-        datapoints.Mask,
-        datapoints.Video,
-    } - registered_input_types
-    if missing:
-        names = sorted(str(t) for t in missing)
-        raise AssertionError(
-            "\n".join(
-                [
-                    f"The dispatcher '{dispatcher.__name__}' has no kernel registered for",
-                    "",
-                    *[f"- {name}" for name in names],
-                    "",
-                    f"If available, register the kernels with @_register_kernel_internal({dispatcher.__name__}, ...).",
-                    f"If not, register explicit no-ops with @_register_explicit_noop({', '.join(names)})",
-                ]
-            )
-        )
-
 class TestResize:
     INPUT_SIZE = (17, 11)
     OUTPUT_SIZES = [17, [17], (17,), [12, 13], (12, 13)]
...
@@ -2128,9 +2099,20 @@ class TestRegisterKernel:
         with pytest.raises(ValueError, match="Kernels can only be registered for subclasses"):
             F.register_kernel(F.resize, object)

-        with pytest.raises(ValueError, match="already has a kernel registered for type"):
+        with pytest.raises(ValueError, match="cannot be registered for the builtin datapoint classes"):
             F.register_kernel(F.resize, datapoints.Image)(F.resize_image_tensor)

+        class CustomDatapoint(datapoints.Datapoint):
+            pass
+
+        def resize_custom_datapoint():
+            pass
+
+        F.register_kernel(F.resize, CustomDatapoint)(resize_custom_datapoint)
+
+        with pytest.raises(ValueError, match="already has a kernel registered for type"):
+            F.register_kernel(F.resize, CustomDatapoint)(resize_custom_datapoint)
+
 class TestGetKernel:
     # We are using F.resize as dispatcher and the kernels below as proxy. Any other dispatcher / kernels combination
...
@@ -2152,13 +2134,7 @@ class TestGetKernel:
            pass

        for input_type in [str, int, object, MyTensor, MyPILImage]:
-            with pytest.raises(
-                TypeError,
-                match=(
-                    "supports inputs of type torch.Tensor, PIL.Image.Image, "
-                    "and subclasses of torchvision.datapoints.Datapoint"
-                ),
-            ):
+            with pytest.raises(TypeError, match="supports inputs of type"):
                _get_kernel(F.resize, input_type)

     def test_exact_match(self):
...
@@ -2211,8 +2187,8 @@ class TestGetKernel:
         class MyDatapoint(datapoints.Datapoint):
             pass

-        # Note that this will be an error in the future
-        assert _get_kernel(F.resize, MyDatapoint) is _noop
+        with pytest.raises(TypeError, match="supports inputs of type"):
+            _get_kernel(F.resize, MyDatapoint)

         def resize_my_datapoint():
             pass
...
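The reworked TestRegisterKernel expectations above suggest the intended workflow for third-party datapoints: builtin classes keep their internal kernels (re-registering them is now a hard error), while custom subclasses register a kernel once via F.register_kernel. A usage sketch along the lines of the new test (MyImage and resize_my_image are illustrative names, not part of the commit):

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    class MyImage(datapoints.Datapoint):
        pass

    @F.register_kernel(F.resize, MyImage)
    def resize_my_image(inpt, size, **kwargs):
        # Delegate to the plain-tensor kernel, then restore the subclass.
        out = F.resize_image_tensor(inpt.as_subclass(torch.Tensor), size, **kwargs)
        return out.as_subclass(MyImage)

    # Per the test above, both of these now raise a ValueError:
    #   F.register_kernel(F.resize, MyImage)(resize_my_image)   # duplicate registration
    #   F.register_kernel(F.resize, datapoints.Image)(...)      # builtin datapoint class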
torchvision/prototype/transforms/_geometry.py  (view file @ db310636)

...
@@ -101,7 +101,8 @@ class FixedSizeCrop(Transform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if params["needs_crop"]:
-            inpt = F.crop(
-                inpt,
+            inpt = self._call_kernel(
+                F.crop,
+                inpt,
                 top=params["top"],
                 left=params["left"],
...
@@ -120,6 +121,6 @@ class FixedSizeCrop(Transform):
         if params["needs_pad"]:
             fill = _get_fill(self._fill, type(inpt))
-            inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)
+            inpt = self._call_kernel(F.pad, inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)

         return inpt
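The self._call_kernel helper introduced here comes from the base Transform class in torchvision/transforms/v2/_transform.py (+6 lines; that hunk is not part of this excerpt). Judging from how it is used throughout the diff, a plausible reconstruction follows; the allow_passthrough flag is an assumption, not confirmed by the excerpt:

    # Hypothetical sketch of the new Transform._call_kernel hook.
    def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
        # Resolve the kernel for this input type; datapoint subclasses without a
        # registered kernel resolve to an identity kernel, so transforms forward
        # unsupported types unchanged instead of erroring like the bare dispatcher.
        kernel = _get_kernel(dispatcher, type(inpt), allow_passthrough=True)  # assumed signature
        return kernel(inpt, *args, **kwargs)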
torchvision/transforms/v2/_augment.py  (view file @ db310636)

 import math
 import numbers
 import warnings
-from typing import Any, Dict, List, Tuple
+from typing import Any, Callable, Dict, List, Tuple

 import PIL.Image
 import torch
...
@@ -91,6 +91,14 @@ class RandomErasing(_RandomApplyTransform):
         self._log_ratio = torch.log(torch.tensor(self.ratio))

+    def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
+        if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
+            warnings.warn(
+                f"{type(self).__name__}() is currently passing through inputs of type "
+                f"datapoints.{type(inpt).__name__}. This will likely change in the future."
+            )
+        return super()._call_kernel(dispatcher, inpt, *args, **kwargs)
+
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         img_c, img_h, img_w = query_chw(flat_inputs)
...
@@ -131,7 +139,7 @@ class RandomErasing(_RandomApplyTransform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if params["v"] is not None:
-            inpt = F.erase(inpt, **params, inplace=self.inplace)
+            inpt = self._call_kernel(F.erase, inpt, **params, inplace=self.inplace)

         return inpt
...
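With the _call_kernel override above, RandomErasing warns instead of silently skipping when a bounding box or mask flows through it. A quick behavioral sketch (beta datapoints API assumed; the warning text is taken from the diff):

    import torch
    from torchvision import datapoints
    from torchvision.transforms import v2 as transforms

    img = datapoints.Image(torch.rand(3, 32, 32))
    mask = datapoints.Mask(torch.randint(0, 5, size=(32, 32)))

    # The image is erased; the mask should be forwarded unchanged with a warning:
    # "RandomErasing() is currently passing through inputs of type datapoints.Mask.
    #  This will likely change in the future."
    out_img, out_mask = transforms.RandomErasing(p=1.0)(img, mask)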
torchvision/transforms/v2/_color.py  (view file @ db310636)

 import collections.abc
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

-import PIL.Image
 import torch

-from torchvision import datapoints, transforms as _transforms
+from torchvision import transforms as _transforms
 from torchvision.transforms.v2 import functional as F, Transform

 from ._transform import _RandomApplyTransform
-from .utils import is_simple_tensor, query_chw
+from .utils import query_chw

 class Grayscale(Transform):
...
@@ -24,19 +23,12 @@ class Grayscale(Transform):
     _v1_transform_cls = _transforms.Grayscale

-    _transformed_types = (
-        datapoints.Image,
-        PIL.Image.Image,
-        is_simple_tensor,
-        datapoints.Video,
-    )
-
     def __init__(self, num_output_channels: int = 1):
         super().__init__()
         self.num_output_channels = num_output_channels

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
+        return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=self.num_output_channels)

 class RandomGrayscale(_RandomApplyTransform):
...
@@ -55,13 +47,6 @@ class RandomGrayscale(_RandomApplyTransform):
     _v1_transform_cls = _transforms.RandomGrayscale

-    _transformed_types = (
-        datapoints.Image,
-        PIL.Image.Image,
-        is_simple_tensor,
-        datapoints.Video,
-    )
-
     def __init__(self, p: float = 0.1) -> None:
         super().__init__(p=p)
...
@@ -70,7 +55,7 @@ class RandomGrayscale(_RandomApplyTransform):
         return dict(num_input_channels=num_input_channels)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
+        return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=params["num_input_channels"])

 class ColorJitter(Transform):
...
@@ -167,13 +152,13 @@ class ColorJitter(Transform):
         hue_factor = params["hue_factor"]
         for fn_id in params["fn_idx"]:
             if fn_id == 0 and brightness_factor is not None:
-                output = F.adjust_brightness(output, brightness_factor=brightness_factor)
+                output = self._call_kernel(F.adjust_brightness, output, brightness_factor=brightness_factor)
             elif fn_id == 1 and contrast_factor is not None:
-                output = F.adjust_contrast(output, contrast_factor=contrast_factor)
+                output = self._call_kernel(F.adjust_contrast, output, contrast_factor=contrast_factor)
             elif fn_id == 2 and saturation_factor is not None:
-                output = F.adjust_saturation(output, saturation_factor=saturation_factor)
+                output = self._call_kernel(F.adjust_saturation, output, saturation_factor=saturation_factor)
             elif fn_id == 3 and hue_factor is not None:
-                output = F.adjust_hue(output, hue_factor=hue_factor)
+                output = self._call_kernel(F.adjust_hue, output, hue_factor=hue_factor)
         return output
...
@@ -183,19 +168,12 @@ class RandomChannelPermutation(Transform):
     .. v2betastatus:: RandomChannelPermutation transform
     """

-    _transformed_types = (
-        datapoints.Image,
-        PIL.Image.Image,
-        is_simple_tensor,
-        datapoints.Video,
-    )
-
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         num_channels, *_ = query_chw(flat_inputs)
         return dict(permutation=torch.randperm(num_channels))

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.permute_channels(inpt, params["permutation"])
+        return self._call_kernel(F.permute_channels, inpt, params["permutation"])

 class RandomPhotometricDistort(Transform):
...
@@ -224,13 +202,6 @@ class RandomPhotometricDistort(Transform):
         Default is 0.5.
     """

-    _transformed_types = (
-        datapoints.Image,
-        PIL.Image.Image,
-        is_simple_tensor,
-        datapoints.Video,
-    )
-
     def __init__(
         self,
         brightness: Tuple[float, float] = (0.875, 1.125),
...
@@ -263,17 +234,17 @@ class RandomPhotometricDistort(Transform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if params["brightness_factor"] is not None:
-            inpt = F.adjust_brightness(inpt, brightness_factor=params["brightness_factor"])
+            inpt = self._call_kernel(F.adjust_brightness, inpt, brightness_factor=params["brightness_factor"])
         if params["contrast_factor"] is not None and params["contrast_before"]:
-            inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"])
+            inpt = self._call_kernel(F.adjust_contrast, inpt, contrast_factor=params["contrast_factor"])
         if params["saturation_factor"] is not None:
-            inpt = F.adjust_saturation(inpt, saturation_factor=params["saturation_factor"])
+            inpt = self._call_kernel(F.adjust_saturation, inpt, saturation_factor=params["saturation_factor"])
         if params["hue_factor"] is not None:
-            inpt = F.adjust_hue(inpt, hue_factor=params["hue_factor"])
+            inpt = self._call_kernel(F.adjust_hue, inpt, hue_factor=params["hue_factor"])
         if params["contrast_factor"] is not None and not params["contrast_before"]:
-            inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"])
+            inpt = self._call_kernel(F.adjust_contrast, inpt, contrast_factor=params["contrast_factor"])
         if params["channel_permutation"] is not None:
-            inpt = F.permute_channels(inpt, permutation=params["channel_permutation"])
+            inpt = self._call_kernel(F.permute_channels, inpt, permutation=params["channel_permutation"])
         return inpt
...
@@ -293,7 +264,7 @@ class RandomEqualize(_RandomApplyTransform):
     _v1_transform_cls = _transforms.RandomEqualize

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.equalize(inpt)
+        return self._call_kernel(F.equalize, inpt)

 class RandomInvert(_RandomApplyTransform):
...
@@ -312,7 +283,7 @@ class RandomInvert(_RandomApplyTransform):
     _v1_transform_cls = _transforms.RandomInvert

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.invert(inpt)
+        return self._call_kernel(F.invert, inpt)

 class RandomPosterize(_RandomApplyTransform):
...
@@ -337,7 +308,7 @@ class RandomPosterize(_RandomApplyTransform):
         self.bits = bits

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.posterize(inpt, bits=self.bits)
+        return self._call_kernel(F.posterize, inpt, bits=self.bits)

 class RandomSolarize(_RandomApplyTransform):
...
@@ -362,7 +333,7 @@ class RandomSolarize(_RandomApplyTransform):
         self.threshold = threshold

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.solarize(inpt, threshold=self.threshold)
+        return self._call_kernel(F.solarize, inpt, threshold=self.threshold)

 class RandomAutocontrast(_RandomApplyTransform):
...
@@ -381,7 +352,7 @@ class RandomAutocontrast(_RandomApplyTransform):
     _v1_transform_cls = _transforms.RandomAutocontrast

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.autocontrast(inpt)
+        return self._call_kernel(F.autocontrast, inpt)

 class RandomAdjustSharpness(_RandomApplyTransform):
...
@@ -406,4 +377,4 @@ class RandomAdjustSharpness(_RandomApplyTransform):
         self.sharpness_factor = sharpness_factor

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.adjust_sharpness(inpt, sharpness_factor=self.sharpness_factor)
+        return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=self.sharpness_factor)
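Dropping the _transformed_types allow-lists from the color transforms is possible because unknown types now fall through the kernel registry instead: a BoundingBoxes or Mask reaching, say, F.equalize via _call_kernel has no registered kernel and is returned unchanged. A short sketch (beta datapoints API assumed):

    import torch
    from torchvision import datapoints
    from torchvision.transforms import v2 as transforms

    img = datapoints.Image(torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8))
    boxes = datapoints.BoundingBoxes([[0, 0, 8, 8]], format="XYXY", canvas_size=(16, 16))

    # Previously the allow-list skipped non-image types before _transform was called;
    # now the kernel lookup inside _call_kernel passes them through untouched.
    out_img, out_boxes = transforms.RandomEqualize(p=1.0)(img, boxes)
    assert (out_boxes == boxes).all()  # a color op leaves boxes unaffected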
torchvision/transforms/v2/_geometry.py
View file @
db310636
import
math
import
math
import
numbers
import
numbers
import
warnings
import
warnings
from
typing
import
Any
,
cast
,
Dict
,
List
,
Literal
,
Optional
,
Sequence
,
Tuple
,
Type
,
Union
from
typing
import
Any
,
Callable
,
cast
,
Dict
,
List
,
Literal
,
Optional
,
Sequence
,
Tuple
,
Type
,
Union
import
PIL.Image
import
PIL.Image
import
torch
import
torch
...
@@ -44,7 +44,7 @@ class RandomHorizontalFlip(_RandomApplyTransform):
...
@@ -44,7 +44,7 @@ class RandomHorizontalFlip(_RandomApplyTransform):
_v1_transform_cls
=
_transforms
.
RandomHorizontalFlip
_v1_transform_cls
=
_transforms
.
RandomHorizontalFlip
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
return
F
.
horizontal_flip
(
inpt
)
return
self
.
_call_kernel
(
F
.
horizontal_flip
,
inpt
)
class
RandomVerticalFlip
(
_RandomApplyTransform
):
class
RandomVerticalFlip
(
_RandomApplyTransform
):
...
@@ -64,7 +64,7 @@ class RandomVerticalFlip(_RandomApplyTransform):
...
@@ -64,7 +64,7 @@ class RandomVerticalFlip(_RandomApplyTransform):
_v1_transform_cls
=
_transforms
.
RandomVerticalFlip
_v1_transform_cls
=
_transforms
.
RandomVerticalFlip
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
return
F
.
vertical_flip
(
inpt
)
return
self
.
_call_kernel
(
F
.
vertical_flip
,
inpt
)
class
Resize
(
Transform
):
class
Resize
(
Transform
):
...
@@ -152,7 +152,8 @@ class Resize(Transform):
...
@@ -152,7 +152,8 @@ class Resize(Transform):
self
.
antialias
=
antialias
self
.
antialias
=
antialias
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
return
F
.
resize
(
return
self
.
_call_kernel
(
F
.
resize
,
inpt
,
inpt
,
self
.
size
,
self
.
size
,
interpolation
=
self
.
interpolation
,
interpolation
=
self
.
interpolation
,
...
@@ -186,7 +187,7 @@ class CenterCrop(Transform):
...
@@ -186,7 +187,7 @@ class CenterCrop(Transform):
self
.
size
=
_setup_size
(
size
,
error_msg
=
"Please provide only two dimensions (h, w) for size."
)
self
.
size
=
_setup_size
(
size
,
error_msg
=
"Please provide only two dimensions (h, w) for size."
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
return
F
.
center_crop
(
inpt
,
output_size
=
self
.
size
)
return
self
.
_call_kernel
(
F
.
center_crop
,
inpt
,
output_size
=
self
.
size
)
class
RandomResizedCrop
(
Transform
):
class
RandomResizedCrop
(
Transform
):
...
@@ -307,8 +308,8 @@ class RandomResizedCrop(Transform):
...
@@ -307,8 +308,8 @@ class RandomResizedCrop(Transform):
return
dict
(
top
=
i
,
left
=
j
,
height
=
h
,
width
=
w
)
return
dict
(
top
=
i
,
left
=
j
,
height
=
h
,
width
=
w
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
return
F
.
resized_crop
(
return
self
.
_call_kernel
(
inpt
,
**
params
,
size
=
self
.
size
,
interpolation
=
self
.
interpolation
,
antialias
=
self
.
antialias
F
.
resized_crop
,
inpt
,
**
params
,
size
=
self
.
size
,
interpolation
=
self
.
interpolation
,
antialias
=
self
.
antialias
)
)
...
@@ -357,8 +358,16 @@ class FiveCrop(Transform):
...
@@ -357,8 +358,16 @@ class FiveCrop(Transform):
super
().
__init__
()
super
().
__init__
()
self
.
size
=
_setup_size
(
size
,
error_msg
=
"Please provide only two dimensions (h, w) for size."
)
self
.
size
=
_setup_size
(
size
,
error_msg
=
"Please provide only two dimensions (h, w) for size."
)
def
_call_kernel
(
self
,
dispatcher
:
Callable
,
inpt
:
Any
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
if
isinstance
(
inpt
,
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)):
warnings
.
warn
(
f
"
{
type
(
self
).
__name__
}
() is currently passing through inputs of type "
f
"datapoints.
{
type
(
inpt
).
__name__
}
. This will likely change in the future."
)
return
super
().
_call_kernel
(
dispatcher
,
inpt
,
*
args
,
**
kwargs
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
return
F
.
five_crop
(
inpt
,
self
.
size
)
return
self
.
_call_kernel
(
F
.
five_crop
,
inpt
,
self
.
size
)
def
_check_inputs
(
self
,
flat_inputs
:
List
[
Any
])
->
None
:
def
_check_inputs
(
self
,
flat_inputs
:
List
[
Any
])
->
None
:
if
has_any
(
flat_inputs
,
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
):
if
has_any
(
flat_inputs
,
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
):
...
@@ -396,12 +405,20 @@ class TenCrop(Transform):
...
@@ -396,12 +405,20 @@ class TenCrop(Transform):
self
.
size
=
_setup_size
(
size
,
error_msg
=
"Please provide only two dimensions (h, w) for size."
)
self
.
size
=
_setup_size
(
size
,
error_msg
=
"Please provide only two dimensions (h, w) for size."
)
self
.
vertical_flip
=
vertical_flip
self
.
vertical_flip
=
vertical_flip
def
_call_kernel
(
self
,
dispatcher
:
Callable
,
inpt
:
Any
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
if
isinstance
(
inpt
,
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)):
warnings
.
warn
(
f
"
{
type
(
self
).
__name__
}
() is currently passing through inputs of type "
f
"datapoints.
{
type
(
inpt
).
__name__
}
. This will likely change in the future."
)
return
super
().
_call_kernel
(
dispatcher
,
inpt
,
*
args
,
**
kwargs
)
def
_check_inputs
(
self
,
flat_inputs
:
List
[
Any
])
->
None
:
def
_check_inputs
(
self
,
flat_inputs
:
List
[
Any
])
->
None
:
if
has_any
(
flat_inputs
,
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
):
if
has_any
(
flat_inputs
,
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
):
raise
TypeError
(
f
"BoundingBoxes'es and Mask's are not supported by
{
type
(
self
).
__name__
}
()"
)
raise
TypeError
(
f
"BoundingBoxes'es and Mask's are not supported by
{
type
(
self
).
__name__
}
()"
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
return
F
.
ten_crop
(
inpt
,
self
.
size
,
vertical_flip
=
self
.
vertical_flip
)
return
self
.
_call_kernel
(
F
.
ten_crop
,
inpt
,
self
.
size
,
vertical_flip
=
self
.
vertical_flip
)
class
Pad
(
Transform
):
class
Pad
(
Transform
):
...
@@ -475,7 +492,7 @@ class Pad(Transform):
...
@@ -475,7 +492,7 @@ class Pad(Transform):
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
fill
=
_get_fill
(
self
.
_fill
,
type
(
inpt
))
fill
=
_get_fill
(
self
.
_fill
,
type
(
inpt
))
return
F
.
pad
(
inpt
,
padding
=
self
.
padding
,
fill
=
fill
,
padding_mode
=
self
.
padding_mode
)
# type: ignore[arg-type]
return
self
.
_call_kernel
(
F
.
pad
,
inpt
,
padding
=
self
.
padding
,
fill
=
fill
,
padding_mode
=
self
.
padding_mode
)
# type: ignore[arg-type]
class
RandomZoomOut
(
_RandomApplyTransform
):
class
RandomZoomOut
(
_RandomApplyTransform
):
...
@@ -545,7 +562,7 @@ class RandomZoomOut(_RandomApplyTransform):
...
@@ -545,7 +562,7 @@ class RandomZoomOut(_RandomApplyTransform):
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
fill
=
_get_fill
(
self
.
_fill
,
type
(
inpt
))
fill
=
_get_fill
(
self
.
_fill
,
type
(
inpt
))
return
F
.
pad
(
inpt
,
**
params
,
fill
=
fill
)
return
self
.
_call_kernel
(
F
.
pad
,
inpt
,
**
params
,
fill
=
fill
)
class
RandomRotation
(
Transform
):
class
RandomRotation
(
Transform
):
...
@@ -611,7 +628,8 @@ class RandomRotation(Transform):
...
@@ -611,7 +628,8 @@ class RandomRotation(Transform):
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
fill
=
_get_fill
(
self
.
_fill
,
type
(
inpt
))
fill
=
_get_fill
(
self
.
_fill
,
type
(
inpt
))
return
F
.
rotate
(
return
self
.
_call_kernel
(
F
.
rotate
,
inpt
,
inpt
,
**
params
,
**
params
,
interpolation
=
self
.
interpolation
,
interpolation
=
self
.
interpolation
,
...
@@ -733,7 +751,8 @@ class RandomAffine(Transform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
-        return F.affine(
+        return self._call_kernel(
+            F.affine,
             inpt,
             **params,
             interpolation=self.interpolation,
...
@@ -889,10 +908,12 @@ class RandomCrop(Transform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if params["needs_pad"]:
             fill = _get_fill(self._fill, type(inpt))
-            inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)
+            inpt = self._call_kernel(F.pad, inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)

         if params["needs_crop"]:
-            inpt = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])
+            inpt = self._call_kernel(
+                F.crop, inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
+            )

         return inpt
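This is the commit's headline behavior in one place: every kernel call inside a transform now goes through self._call_kernel, which passes inputs of unknown datapoint types through untouched instead of relying on no-op kernels registered on the dispatcher. A sketch of the user-visible consequence (the custom datapoint subclass and the as_subclass construction are illustrative assumptions):

    import torch
    from torchvision import datapoints
    from torchvision.transforms import v2

    class MyDatapoint(datapoints.Datapoint):
        pass  # no kernels registered for this type

    crop = v2.RandomCrop(size=32)
    sample = {
        "image": datapoints.Image(torch.rand(3, 64, 64)),
        "extra": torch.rand(4).as_subclass(MyDatapoint),
    }
    out = crop(sample)
    # out["image"] is cropped to 3 x 32 x 32, while out["extra"] comes back
    # unchanged: _call_kernel resolved it to a passthrough kernel.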
...
@@ -973,7 +994,8 @@ class RandomPerspective(_RandomApplyTransform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
-        return F.perspective(
+        return self._call_kernel(
+            F.perspective,
             inpt,
             None,
             None,
...
@@ -1050,7 +1072,7 @@ class ElasticTransform(Transform):
             # if kernel size is even we have to make it odd
             if kx % 2 == 0:
                 kx += 1
-            dx = F.gaussian_blur(dx, [kx, kx], list(self.sigma))
+            dx = self._call_kernel(F.gaussian_blur, dx, [kx, kx], list(self.sigma))
         dx = dx * self.alpha[0] / size[0]

         dy = torch.rand([1, 1] + size) * 2 - 1
...
@@ -1059,14 +1081,15 @@ class ElasticTransform(Transform):
             # if kernel size is even we have to make it odd
             if ky % 2 == 0:
                 ky += 1
-            dy = F.gaussian_blur(dy, [ky, ky], list(self.sigma))
+            dy = self._call_kernel(F.gaussian_blur, dy, [ky, ky], list(self.sigma))
         dy = dy * self.alpha[1] / size[1]
         displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1])  # 1 x H x W x 2
         return dict(displacement=displacement)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
-        return F.elastic(
+        return self._call_kernel(
+            F.elastic,
             inpt,
             **params,
             fill=fill,
...
@@ -1164,7 +1187,9 @@ class RandomIoUCrop(Transform):
                 # check for any valid boxes with centers within the crop area
                 xyxy_bboxes = F.convert_format_bounding_boxes(
-                    bboxes.as_subclass(torch.Tensor), bboxes.format, datapoints.BoundingBoxFormat.XYXY
+                    bboxes.as_subclass(torch.Tensor),
+                    bboxes.format,
+                    datapoints.BoundingBoxFormat.XYXY,
                 )
                 cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])
                 cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3])
...
@@ -1188,7 +1213,9 @@ class RandomIoUCrop(Transform):
         if len(params) < 1:
             return inpt

-        output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])
+        output = self._call_kernel(
+            F.crop, inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
+        )

         if isinstance(output, datapoints.BoundingBoxes):
             # We "mark" the invalid boxes as degenerate, and they can be
...
@@ -1262,7 +1289,9 @@ class ScaleJitter(Transform):
         return dict(size=(new_height, new_width))

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias)
+        return self._call_kernel(
+            F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias
+        )
RandomShortestSize
(
Transform
):
class
RandomShortestSize
(
Transform
):
...
@@ -1330,7 +1359,9 @@ class RandomShortestSize(Transform):
...
@@ -1330,7 +1359,9 @@ class RandomShortestSize(Transform):
return
dict
(
size
=
(
new_height
,
new_width
))
return
dict
(
size
=
(
new_height
,
new_width
))
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
return
F
.
resize
(
inpt
,
size
=
params
[
"size"
],
interpolation
=
self
.
interpolation
,
antialias
=
self
.
antialias
)
return
self
.
_call_kernel
(
F
.
resize
,
inpt
,
size
=
params
[
"size"
],
interpolation
=
self
.
interpolation
,
antialias
=
self
.
antialias
)
 class RandomResize(Transform):
...
@@ -1400,4 +1431,6 @@ class RandomResize(Transform):
         return dict(size=[size])

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.resize(inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias)
+        return self._call_kernel(
+            F.resize, inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias
+        )
torchvision/transforms/v2/_misc.py  View file @ db310636
...
@@ -106,7 +106,7 @@ class LinearTransformation(Transform):
     def _check_inputs(self, sample: Any) -> Any:
         if has_any(sample, PIL.Image.Image):
-            raise TypeError("LinearTransformation does not work on PIL Images")
+            raise TypeError(f"{type(self).__name__}() does not support PIL images.")

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         shape = inpt.shape
...
@@ -157,7 +157,6 @@ class Normalize(Transform):
     """

     _v1_transform_cls = _transforms.Normalize
-    _transformed_types = (datapoints.Image, is_simple_tensor, datapoints.Video)

     def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False):
         super().__init__()
...
@@ -170,7 +169,7 @@ class Normalize(Transform):
             raise TypeError(f"{type(self).__name__}() does not support PIL images.")

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
+        return self._call_kernel(F.normalize, inpt, mean=self.mean, std=self.std, inplace=self.inplace)
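Dropping the _transformed_types allow-list does not change the net behavior: non-image inputs now reach _transform, but _call_kernel resolves them to a passthrough kernel. A sketch of that equivalence, assuming the v2 namespace at this commit:

    import torch
    from torchvision import datapoints
    from torchvision.transforms import v2

    normalize = v2.Normalize(mean=[0.5], std=[0.5])
    bboxes = datapoints.BoundingBoxes(
        torch.tensor([[0, 0, 4, 4]]), format="XYXY", canvas_size=(8, 8)
    )

    # Previously skipped by the _transformed_types allow-list; now routed through
    # _call_kernel, which falls back to a passthrough for BoundingBoxes.
    assert normalize(bboxes) is bboxes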
 class GaussianBlur(Transform):
...
@@ -217,7 +216,7 @@ class GaussianBlur(Transform):
         return dict(sigma=[sigma, sigma])

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.gaussian_blur(inpt, self.kernel_size, **params)
+        return self._call_kernel(F.gaussian_blur, inpt, self.kernel_size, **params)
 class ToDtype(Transform):
...
@@ -290,7 +289,7 @@ class ToDtype(Transform):
             )
             return inpt

-        return F.to_dtype(inpt, dtype=dtype, scale=self.scale)
+        return self._call_kernel(F.to_dtype, inpt, dtype=dtype, scale=self.scale)
 class ConvertImageDtype(Transform):
...
@@ -320,14 +319,12 @@ class ConvertImageDtype(Transform):
     _v1_transform_cls = _transforms.ConvertImageDtype

-    _transformed_types = (is_simple_tensor, datapoints.Image)

     def __init__(self, dtype: torch.dtype = torch.float32) -> None:
         super().__init__()
         self.dtype = dtype

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.to_dtype(inpt, dtype=self.dtype, scale=True)
+        return self._call_kernel(F.to_dtype, inpt, dtype=self.dtype, scale=True)


 class SanitizeBoundingBoxes(Transform):
...
torchvision/transforms/v2/_temporal.py  View file @ db310636
...
@@ -25,4 +25,4 @@ class UniformTemporalSubsample(Transform):
         self.num_samples = num_samples

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.uniform_temporal_subsample(inpt, self.num_samples)
+        return self._call_kernel(F.uniform_temporal_subsample, inpt, self.num_samples)
torchvision/transforms/v2/_transform.py  View file @ db310636
...
@@ -11,6 +11,8 @@ from torchvision import datapoints
 from torchvision.transforms.v2.utils import check_type, has_any, is_simple_tensor
 from torchvision.utils import _log_api_usage_once

+from .functional._utils import _get_kernel
+

 class Transform(nn.Module):
...
@@ -28,6 +30,10 @@ class Transform(nn.Module):
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict()

+    def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
+        kernel = _get_kernel(dispatcher, type(inpt), allow_passthrough=True)
+        return kernel(inpt, *args, **kwargs)
+
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         raise NotImplementedError
...
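This four-line hook is the crux of the commit: transforms opt in to passthrough via allow_passthrough=True, while direct calls to the functionals keep strict dispatch. A sketch of the resulting split, assuming datapoints and the v2 functional namespace at this commit:

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    bboxes = datapoints.BoundingBoxes(
        torch.tensor([[0, 0, 4, 4]]), format="XYXY", canvas_size=(8, 8)
    )

    # Direct dispatcher call: no kernel is registered for BoundingBoxes, so this
    # now raises instead of silently returning the input.
    try:
        F.adjust_brightness(bboxes, brightness_factor=2.0)
    except TypeError as e:
        print(e)

    # Inside a transform, the same lookup runs with allow_passthrough=True and
    # resolves to a no-op, so the boxes ride through pipelines unchanged.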
torchvision/transforms/v2/functional/_augment.py  View file @ db310636
...
@@ -5,10 +5,9 @@ from torchvision import datapoints
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 from torchvision.utils import _log_api_usage_once

-from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
+from ._utils import _get_kernel, _register_kernel_internal


-@_register_explicit_noop(datapoints.Mask, datapoints.BoundingBoxes, warn_passthrough=True)
 def erase(
     inpt: torch.Tensor,
     i: int,
...
torchvision/transforms/v2/functional/_color.py  View file @ db310636
...
@@ -10,12 +10,10 @@ from torchvision.transforms._functional_tensor import _max_value
 from torchvision.utils import _log_api_usage_once

 from ._misc import _num_value_bits, to_dtype_image_tensor
 from ._type_conversion import pil_to_tensor, to_image_pil
-from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
+from ._utils import _get_kernel, _register_kernel_internal


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, datapoints.Video)
 def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
     if torch.jit.is_scripting():
         return rgb_to_grayscale_image_tensor(inpt, num_output_channels=num_output_channels)
...
@@ -70,8 +68,8 @@ def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Te
     return output if fp else output.to(image1.dtype)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def adjust_brightness(inpt: torch.Tensor, brightness_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
...
@@ -107,7 +105,6 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to
     return adjust_brightness_image_tensor(video, brightness_factor=brightness_factor)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
...
@@ -146,7 +143,6 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to
     return adjust_saturation_image_tensor(video, saturation_factor=saturation_factor)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
...
@@ -185,7 +181,6 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.
     return adjust_contrast_image_tensor(video, contrast_factor=contrast_factor)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
...
@@ -258,7 +253,6 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc
     return adjust_sharpness_image_tensor(video, sharpness_factor=sharpness_factor)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def adjust_hue(inpt: torch.Tensor, hue_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
...
@@ -370,7 +364,6 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
     return adjust_hue_image_tensor(video, hue_factor=hue_factor)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
...
@@ -410,7 +403,6 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to
     return adjust_gamma_image_tensor(video, gamma=gamma, gain=gain)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor:
     if torch.jit.is_scripting():
         return posterize_image_tensor(inpt, bits=bits)
...
@@ -444,7 +436,6 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
     return posterize_image_tensor(video, bits=bits)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def solarize(inpt: torch.Tensor, threshold: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return solarize_image_tensor(inpt, threshold=threshold)
...
@@ -472,7 +463,6 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
     return solarize_image_tensor(video, threshold=threshold)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def autocontrast(inpt: torch.Tensor) -> torch.Tensor:
     if torch.jit.is_scripting():
         return autocontrast_image_tensor(inpt)
...
@@ -522,7 +512,6 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
     return autocontrast_image_tensor(video)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def equalize(inpt: torch.Tensor) -> torch.Tensor:
     if torch.jit.is_scripting():
         return equalize_image_tensor(inpt)
...
@@ -612,7 +601,6 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
     return equalize_image_tensor(video)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def invert(inpt: torch.Tensor) -> torch.Tensor:
     if torch.jit.is_scripting():
         return invert_image_tensor(inpt)
...
@@ -643,7 +631,6 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:
     return invert_image_tensor(video)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor:
     """Permute the channels of the input according to the given permutation.
...
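Each removed decorator had registered a warning or silent no-op on the dispatcher itself; with them gone, the color functionals raise for box and mask inputs. A sketch of the new failure mode, assuming the v2 functional namespace at this commit:

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    mask = datapoints.Mask(torch.zeros(1, 8, 8, dtype=torch.uint8))

    # Before: the _register_explicit_noop registration returned the mask unchanged.
    # After: no kernel is registered for Mask, so dispatch fails loudly.
    try:
        F.equalize(mask)
    except TypeError as e:
        print(e)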
torchvision/transforms/v2/functional/_geometry.py  View file @ db310636
...
@@ -25,13 +25,7 @@ from torchvision.utils import _log_api_usage_once
 from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_size_image_pil
-from ._utils import (
-    _FillTypeJIT,
-    _get_kernel,
-    _register_explicit_noop,
-    _register_five_ten_crop_kernel,
-    _register_kernel_internal,
-)
+from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal


 def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
...
@@ -2203,7 +2197,6 @@ def resized_crop_video(
     )


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True)
 def five_crop(
     inpt: torch.Tensor, size: List[int]
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
...
@@ -2230,8 +2223,8 @@ def _parse_five_crop_size(size: List[int]) -> List[int]:
     return size


-@_register_five_ten_crop_kernel(five_crop, torch.Tensor)
-@_register_five_ten_crop_kernel(five_crop, datapoints.Image)
+@_register_five_ten_crop_kernel_internal(five_crop, torch.Tensor)
+@_register_five_ten_crop_kernel_internal(five_crop, datapoints.Image)
 def five_crop_image_tensor(
     image: torch.Tensor, size: List[int]
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
...
@@ -2250,7 +2243,7 @@ def five_crop_image_tensor(
     return tl, tr, bl, br, center


-@_register_five_ten_crop_kernel(five_crop, PIL.Image.Image)
+@_register_five_ten_crop_kernel_internal(five_crop, PIL.Image.Image)
 def five_crop_image_pil(
     image: PIL.Image.Image, size: List[int]
 ) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
...
@@ -2269,14 +2262,13 @@ def five_crop_image_pil(
     return tl, tr, bl, br, center


-@_register_five_ten_crop_kernel(five_crop, datapoints.Video)
+@_register_five_ten_crop_kernel_internal(five_crop, datapoints.Video)
 def five_crop_video(
     video: torch.Tensor, size: List[int]
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     return five_crop_image_tensor(video, size)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True)
 def ten_crop(
     inpt: torch.Tensor, size: List[int], vertical_flip: bool = False
 ) -> Tuple[
...
@@ -2300,8 +2292,8 @@ def ten_crop(
     return kernel(inpt, size=size, vertical_flip=vertical_flip)


-@_register_five_ten_crop_kernel(ten_crop, torch.Tensor)
-@_register_five_ten_crop_kernel(ten_crop, datapoints.Image)
+@_register_five_ten_crop_kernel_internal(ten_crop, torch.Tensor)
+@_register_five_ten_crop_kernel_internal(ten_crop, datapoints.Image)
 def ten_crop_image_tensor(
     image: torch.Tensor, size: List[int], vertical_flip: bool = False
 ) -> Tuple[
...
@@ -2328,7 +2320,7 @@ def ten_crop_image_tensor(
     return non_flipped + flipped


-@_register_five_ten_crop_kernel(ten_crop, PIL.Image.Image)
+@_register_five_ten_crop_kernel_internal(ten_crop, PIL.Image.Image)
 def ten_crop_image_pil(
     image: PIL.Image.Image, size: List[int], vertical_flip: bool = False
 ) -> Tuple[
...
@@ -2355,7 +2347,7 @@ def ten_crop_image_pil(
     return non_flipped + flipped


-@_register_five_ten_crop_kernel(ten_crop, datapoints.Video)
+@_register_five_ten_crop_kernel_internal(ten_crop, datapoints.Video)
 def ten_crop_video(
     video: torch.Tensor, size: List[int], vertical_flip: bool = False
 ) -> Tuple[
...
torchvision/transforms/v2/functional/_meta.py  View file @ db310636
...
@@ -8,10 +8,9 @@ from torchvision.transforms import _functional_pil as _FP
 from torchvision.utils import _log_api_usage_once

-from ._utils import _get_kernel, _register_kernel_internal, _register_unsupported_type, is_simple_tensor
+from ._utils import _get_kernel, _register_kernel_internal, is_simple_tensor


-@_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
 def get_dimensions(inpt: torch.Tensor) -> List[int]:
     if torch.jit.is_scripting():
         return get_dimensions_image_tensor(inpt)
...
@@ -44,7 +43,6 @@ def get_dimensions_video(video: torch.Tensor) -> List[int]:
     return get_dimensions_image_tensor(video)


-@_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
 def get_num_channels(inpt: torch.Tensor) -> int:
     if torch.jit.is_scripting():
         return get_num_channels_image_tensor(inpt)
...
@@ -123,7 +121,6 @@ def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]
     return list(bounding_box.canvas_size)


-@_register_unsupported_type(PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask)
 def get_num_frames(inpt: torch.Tensor) -> int:
     if torch.jit.is_scripting():
         return get_num_frames_video(inpt)
...
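The _register_unsupported_type registrations existed only to turn the old silent passthrough into an error for types where a query makes no sense; now that unsupported types error by default in _get_kernel, the decorators are redundant. The observable behavior stays an error, just raised by generic dispatch. A sketch, assuming the v2 functional namespace at this commit:

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    image = datapoints.Image(torch.rand(3, 8, 8))

    # get_num_frames is only defined for videos; an Image input still raises,
    # now via the generic "no kernel registered" TypeError.
    try:
        F.get_num_frames(image)
    except TypeError as e:
        print(e)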
torchvision/transforms/v2/functional/_misc.py  View file @ db310636
...
@@ -11,11 +11,9 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 from torchvision.utils import _log_api_usage_once

-from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, _register_unsupported_type
+from ._utils import _get_kernel, _register_kernel_internal


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-@_register_unsupported_type(PIL.Image.Image)
 def normalize(
     inpt: torch.Tensor,
     mean: List[float],
...
@@ -73,7 +71,6 @@ def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], in
     return normalize_image_tensor(video, mean, std, inplace=inplace)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def gaussian_blur(inpt: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> torch.Tensor:
     if torch.jit.is_scripting():
         return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
...
@@ -182,7 +179,6 @@ def gaussian_blur_video(
     return gaussian_blur_image_tensor(video, kernel_size, sigma)


-@_register_unsupported_type(PIL.Image.Image)
 def to_dtype(inpt: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
     if torch.jit.is_scripting():
         return to_dtype_image_tensor(inpt, dtype=dtype, scale=scale)
...
torchvision/transforms/v2/functional/_temporal.py  View file @ db310636
-import PIL.Image
 import torch

 from torchvision import datapoints
 from torchvision.utils import _log_api_usage_once

-from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
+from ._utils import _get_kernel, _register_kernel_internal


-@_register_explicit_noop(
-    PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True
-)
 def uniform_temporal_subsample(inpt: torch.Tensor, num_samples: int) -> torch.Tensor:
     if torch.jit.is_scripting():
         return uniform_temporal_subsample_video(inpt, num_samples=num_samples)
...
torchvision/transforms/v2/functional/_utils.py  View file @ db310636
 import functools
-import warnings
 from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union

 import torch
...
@@ -53,6 +52,11 @@ def _name_to_dispatcher(name):
         ) from None


+_BUILTIN_DATAPOINT_TYPES = {
+    obj for obj in datapoints.__dict__.values() if isinstance(obj, type) and issubclass(obj, datapoints.Datapoint)
+}
+
+
 def register_kernel(dispatcher, datapoint_cls):
     """Decorate a kernel to register it for a dispatcher and a (custom) datapoint type.
...
@@ -70,20 +74,19 @@ def register_kernel(dispatcher, datapoint_cls):
             f"but got {dispatcher}."
         )
     if not (
-        isinstance(datapoint_cls, type) and issubclass(datapoint_cls, datapoints.Datapoint)):
+        isinstance(datapoint_cls, type) and issubclass(datapoint_cls, datapoints.Datapoint)
+        and datapoint_cls is not datapoints.Datapoint
+    ):
         raise ValueError(
             f"Kernels can only be registered for subclasses of torchvision.datapoints.Datapoint, "
             f"but got {datapoint_cls}."
         )

+    if datapoint_cls in _BUILTIN_DATAPOINT_TYPES:
+        raise ValueError(f"Kernels cannot be registered for the builtin datapoint classes, but got {datapoint_cls}")
+
     return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False)
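With the no-op registrations gone, nothing on the dispatcher side would stop a user from overriding builtin behavior, so register_kernel now rejects builtin datapoint classes outright while custom subclasses keep working. A sketch (the custom class and kernel body are illustrative):

    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    class MyDatapoint(datapoints.Datapoint):
        pass

    # Registering a kernel for a custom datapoint is still supported:
    @F.register_kernel(F.resize, MyDatapoint)
    def resize_my_datapoint(inpt, size, **kwargs):
        return inpt  # a real kernel would resize here

    # Registering against a builtin datapoint class now fails up front:
    try:
        F.register_kernel(F.resize, datapoints.Image)
    except ValueError as e:
        print(e)  # Kernels cannot be registered for the builtin datapoint classes ...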
-def _get_kernel(dispatcher, input_type):
+def _get_kernel(dispatcher, input_type, *, allow_passthrough=False):
     registry = _KERNEL_REGISTRY.get(dispatcher)
     if not registry:
         raise ValueError(f"No kernel registered for dispatcher {dispatcher.__name__}.")
...
@@ -104,78 +107,18 @@ def _get_kernel(dispatcher, input_type):
         elif cls in registry:
             return registry[cls]

-    # Note that in the future we are not going to return a noop here, but rather raise the error below
-    return _noop
+    if allow_passthrough:
+        return lambda inpt, *args, **kwargs: inpt

     raise TypeError(
-        f"Dispatcher {dispatcher} supports inputs of type torch.Tensor, PIL.Image.Image, "
-        f"and subclasses of torchvision.datapoints.Datapoint, "
+        f"Dispatcher F.{dispatcher.__name__} supports inputs of type {registry.keys()}, "
         f"but got {input_type} instead."
     )


-# Everything below this block is stuff that we need right now, since it looks like we need to release in an intermediate
-# stage. See https://github.com/pytorch/vision/pull/7747#issuecomment-1661698450 for details.
-# In the future, the default behavior will be to error on unsupported types in dispatchers. The noop behavior that we
-# need for transforms will be handled by _get_kernel rather than actually registering no-ops on the dispatcher.
-# Finally, the use case of preventing users from registering kernels for our builtin types will be handled inside
-# register_kernel.
-def _register_explicit_noop(*datapoints_classes, warn_passthrough=False):
-    """
-    Although this looks redundant with the no-op behavior of _get_kernel, this explicit registration prevents users
-    from registering kernels for builtin datapoints on builtin dispatchers that rely on the no-op behavior.
-
-    For example, without explicit no-op registration the following would be valid user code:
-
-    .. code::
-
-        from torchvision.transforms.v2 import functional as F
-
-        @F.register_kernel(F.adjust_brightness, datapoints.BoundingBox)
-        def lol(...):
-            ...
-    """
-
-    def decorator(dispatcher):
-        for cls in datapoints_classes:
-            msg = (
-                f"F.{dispatcher.__name__} is currently passing through inputs of type datapoints.{cls.__name__}. "
-                f"This will likely change in the future."
-            )
-            _register_kernel_internal(dispatcher, cls, datapoint_wrapper=False)(
-                functools.partial(_noop, __msg__=msg if warn_passthrough else None)
-            )
-        return dispatcher
-
-    return decorator
-
-
-def _noop(inpt, *args, __msg__=None, **kwargs):
-    if __msg__:
-        warnings.warn(__msg__, UserWarning, stacklevel=2)
-    return inpt
-
-
-# TODO: we only need this, since our default behavior in case no kernel is found is passthrough. When we change that
-# to error later, this decorator can be removed, since the error will be raised by _get_kernel
-def _register_unsupported_type(*input_types):
-    def kernel(inpt, *args, __dispatcher_name__, **kwargs):
-        raise TypeError(f"F.{__dispatcher_name__} does not support inputs of type {type(inpt)}.")
-
-    def decorator(dispatcher):
-        for input_type in input_types:
-            _register_kernel_internal(dispatcher, input_type, datapoint_wrapper=False)(
-                functools.partial(kernel, __dispatcher_name__=dispatcher.__name__)
-            )
-        return dispatcher
-
-    return decorator
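With the temporary no-op machinery gone, all special-casing lives in _get_kernel itself. The resulting resolution order, as a standalone sketch (a simplification with a toy registry, not the real implementation):

    def get_kernel(registry, input_type, *, allow_passthrough=False):
        # Walk the MRO; stop at the Datapoint base so user-defined datapoints
        # never silently fall through to the plain-tensor kernel.
        for cls in input_type.__mro__:
            if cls in registry:
                return registry[cls]
            if cls.__name__ == "Datapoint":
                break
        # Transforms resolve with allow_passthrough=True and get a no-op for
        # unknown types; direct functional calls keep the strict error.
        if allow_passthrough:
            return lambda inpt, *args, **kwargs: inpt
        raise TypeError(f"unsupported input type {input_type}")

    # Toy usage: a kernel registered for a base class is found for subclasses.
    class Base: ...
    class Child(Base): ...
    registry = {Base: lambda inpt: f"base kernel({inpt!r})"}
    assert get_kernel(registry, Child)("img") == "base kernel('img')"
    assert get_kernel(registry, int, allow_passthrough=True)(42) == 42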
 # This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop
 # We could get rid of this by letting _register_kernel_internal take arbitrary dispatchers rather than wrap_kernel: bool
-def _register_five_ten_crop_kernel(dispatcher, input_type):
+def _register_five_ten_crop_kernel_internal(dispatcher, input_type):
     registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
     if input_type in registry:
         raise TypeError(f"Dispatcher '{dispatcher}' already has a kernel registered for type '{input_type}'.")
...