OpenDAS / vision / Commits

Commit ee28bb3c (Unverified)
Authored Oct 02, 2023 by Philip Meier, committed by GitHub Oct 02, 2023
cleanup affine grid image kernels (#8004)
parent f96deba0
Showing 2 changed files with 39 additions and 94 deletions (+39, -94):

  test/test_transforms_v2_refactored.py (+2, -1)
  torchvision/transforms/v2/functional/_geometry.py (+37, -93)
test/test_transforms_v2_refactored.py

@@ -2491,7 +2491,7 @@ class TestElastic:
         interpolation=[transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR],
         fill=EXHAUSTIVE_TYPE_FILLS,
     )
-    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8, torch.float16])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     def test_kernel_image(self, param, value, dtype, device):
         image = make_image_tensor(dtype=dtype, device=device)
@@ -2502,6 +2502,7 @@ class TestElastic:
             displacement=self._make_displacement(image),
             **{param: value},
             check_scripted_vs_eager=not (param == "fill" and isinstance(value, (int, float))),
+            check_cuda_vs_cpu=dtype is not torch.float16,
         )

     @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
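The test change above extends the elastic kernel parametrization to torch.float16 and relaxes the CUDA-vs-CPU comparison for that dtype. Below is a minimal, self-contained sketch of the same pattern; it uses plain pytest/torch with a made-up kernel stand-in, not the repository's check_kernel helper.

import pytest
import torch


@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8, torch.float16])
def test_scale_kernel(dtype):
    # Hypothetical kernel: scale pixel values by 0.5 (stand-in for a real image kernel).
    image = torch.arange(48, dtype=torch.float32).reshape(3, 4, 4).to(dtype)
    cpu_out = (image.float() * 0.5).to(dtype)

    # Mirrors the added check_cuda_vs_cpu=dtype is not torch.float16: half-precision results
    # are not expected to match the CPU output exactly, so skip the cross-device comparison.
    if torch.cuda.is_available() and dtype is not torch.float16:
        cuda_out = (image.cuda().float() * 0.5).to(dtype).cpu()
        torch.testing.assert_close(cuda_out, cpu_out)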
torchvision/transforms/v2/functional/_geometry.py

@@ -551,19 +551,30 @@ def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[in
 def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill: _FillTypeJIT) -> torch.Tensor:
+    input_shape = img.shape
+    output_height, output_width = grid.shape[1], grid.shape[2]
+    num_channels, input_height, input_width = input_shape[-3:]
+    output_shape = input_shape[:-3] + (num_channels, output_height, output_width)
+
+    if img.numel() == 0:
+        return img.reshape(output_shape)
+
+    img = img.reshape(-1, num_channels, input_height, input_width)
+    squashed_batch_size = img.shape[0]

     # We are using context knowledge that grid should have float dtype
     fp = img.dtype == grid.dtype
     float_img = img if fp else img.to(grid.dtype)

-    shape = float_img.shape
-    if shape[0] > 1:
+    if squashed_batch_size > 1:
         # Apply same grid to a batch of images
-        grid = grid.expand(shape[0], -1, -1, -1)
+        grid = grid.expand(squashed_batch_size, -1, -1, -1)

     # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
     if fill is not None:
-        mask = torch.ones((shape[0], 1, shape[2], shape[3]), dtype=float_img.dtype, device=float_img.device)
+        mask = torch.ones(
+            (squashed_batch_size, 1, input_height, input_width), dtype=float_img.dtype, device=float_img.device
+        )
         float_img = torch.cat((float_img, mask), dim=1)

     float_img = grid_sample(float_img, grid, mode=mode, padding_mode="zeros", align_corners=False)
@@ -584,7 +595,7 @@ def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill
     img = float_img.round_().to(img.dtype) if not fp else float_img

-    return img
+    return img.reshape(output_shape)


 def _assert_grid_transform_inputs(
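Note on the two hunks above (my own summary): the reshape-to-4D "squash" of arbitrary leading batch dimensions, previously repeated in each kernel as ndim checks plus needs_unsquash bookkeeping, now lives inside _apply_grid_transform itself. A standalone, illustrative sketch of that pattern, not the exact torchvision code:

import torch
import torch.nn.functional as F


def apply_grid(img: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
    # img: (..., C, H, W) with any number of leading dims; grid: (1, out_h, out_w, 2), float32.
    input_shape = img.shape
    out_h, out_w = grid.shape[1], grid.shape[2]
    c, in_h, in_w = input_shape[-3:]
    output_shape = input_shape[:-3] + (c, out_h, out_w)

    if img.numel() == 0:  # degenerate inputs short-circuit with the right output shape
        return img.reshape(output_shape)

    img = img.reshape(-1, c, in_h, in_w)  # squash all leading dims into one batch dim
    if img.shape[0] > 1:
        grid = grid.expand(img.shape[0], -1, -1, -1)  # share the same grid across the batch

    out = F.grid_sample(img.float(), grid, mode="bilinear", padding_mode="zeros", align_corners=False)
    if not img.dtype.is_floating_point:
        out = out.round_()  # integer inputs are rounded before casting back
    return out.to(img.dtype).reshape(output_shape)  # restore the original leading dims

With an identity grid from torch.nn.functional.affine_grid, this sketch reproduces the input up to interpolation error.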
@@ -661,24 +672,10 @@ def affine_image(
 ) -> torch.Tensor:
     interpolation = _check_interpolation(interpolation)

-    if image.numel() == 0:
-        return image
-
-    shape = image.shape
-    ndim = image.ndim
-
-    if ndim > 4:
-        image = image.reshape((-1,) + shape[-3:])
-        needs_unsquash = True
-    elif ndim == 3:
-        image = image.unsqueeze(0)
-        needs_unsquash = True
-    else:
-        needs_unsquash = False
-
-    height, width = shape[-2:]
     angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

+    height, width = image.shape[-2:]
+
     center_f = [0.0, 0.0]
     if center is not None:
         # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
@@ -692,12 +689,7 @@ def affine_image(
     dtype = image.dtype if torch.is_floating_point(image) else torch.float32
     theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
     grid = _affine_grid(theta, w=width, h=height, ow=width, oh=height)
-    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)
-
-    if needs_unsquash:
-        output = output.reshape(shape)
-
-    return output
+    return _apply_grid_transform(image, grid, interpolation.value, fill=fill)


 @_register_kernel_internal(affine, PIL.Image.Image)
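With the squashing handled centrally, affine_image reduces to building theta, building the grid, and sampling. A usage sketch of the public v2 functional (assuming torchvision 0.16+ with the transforms.v2 API); batched inputs flow through the same kernel:

import torch
from torchvision.transforms.v2 import functional as F

# A (B, C, H, W) uint8 batch; leading dims are handled inside the kernel after this change.
batch = torch.randint(0, 256, (2, 3, 32, 32), dtype=torch.uint8)
out = F.affine(batch, angle=15.0, translate=[2, 0], scale=1.0, shear=[0.0, 0.0], fill=127)
print(out.shape)  # torch.Size([2, 3, 32, 32])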
@@ -969,35 +961,26 @@ def rotate_image(
 ) -> torch.Tensor:
     interpolation = _check_interpolation(interpolation)

-    shape = image.shape
-    num_channels, height, width = shape[-3:]
+    input_height, input_width = image.shape[-2:]

     center_f = [0.0, 0.0]
     if center is not None:
         # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
-        center_f = [(c - s * 0.5) for c, s in zip(center, [width, height])]
+        center_f = [(c - s * 0.5) for c, s in zip(center, [input_width, input_height])]

     # due to current incoherence of rotation angle direction between affine and rotate implementations
     # we need to set -angle.
     matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])

-    if image.numel() > 0:
-        _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])
-
-        image = image.reshape(-1, num_channels, height, width)
-
-        ow, oh = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
-        dtype = image.dtype if torch.is_floating_point(image) else torch.float32
-        theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
-        grid = _affine_grid(theta, w=width, h=height, ow=ow, oh=oh)
-        output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)
-
-        new_height, new_width = output.shape[-2:]
-    else:
-        output = image
-        new_width, new_height = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
-
-    return output.reshape(shape[:-3] + (num_channels, new_height, new_width))
+    _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])
+
+    output_width, output_height = (
+        _compute_affine_output_size(matrix, input_width, input_height) if expand else (input_width, input_height)
+    )
+    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
+    theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
+    grid = _affine_grid(theta, w=input_width, h=input_height, ow=output_width, oh=output_height)
+    return _apply_grid_transform(image, grid, interpolation.value, fill=fill)


 @_register_kernel_internal(rotate, PIL.Image.Image)
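rotate_image now always validates its inputs, computes the output size up front (expanded or not), and returns the grid transform directly instead of branching on image.numel(). A usage sketch of the expand behaviour (again assuming the v2 functional API):

import torch
from torchvision.transforms.v2 import functional as F

image = torch.zeros(3, 100, 200, dtype=torch.uint8)
kept = F.rotate(image, angle=45.0, expand=False)   # keeps the (100, 200) canvas, corners are clipped
grown = F.rotate(image, angle=45.0, expand=True)   # canvas grows to the rotated bounding box
print(kept.shape[-2:], grown.shape[-2:])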
@@ -1509,21 +1492,6 @@ def perspective_image(
     perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
     interpolation = _check_interpolation(interpolation)

-    if image.numel() == 0:
-        return image
-
-    shape = image.shape
-    ndim = image.ndim
-
-    if ndim > 4:
-        image = image.reshape((-1,) + shape[-3:])
-        needs_unsquash = True
-    elif ndim == 3:
-        image = image.unsqueeze(0)
-        needs_unsquash = True
-    else:
-        needs_unsquash = False
-
     _assert_grid_transform_inputs(
         image,
         matrix=None,
@@ -1533,15 +1501,10 @@ def perspective_image(
         coeffs=perspective_coeffs,
     )

-    oh, ow = shape[-2:]
+    oh, ow = image.shape[-2:]
     dtype = image.dtype if torch.is_floating_point(image) else torch.float32
     grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=image.device)
-    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)
-
-    if needs_unsquash:
-        output = output.reshape(shape)
-
-    return output
+    return _apply_grid_transform(image, grid, interpolation.value, fill=fill)


 @_register_kernel_internal(perspective, PIL.Image.Image)
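perspective_image follows the same pattern: the local ndim/needs_unsquash handling is gone and the kernel returns the grid transform directly. A usage sketch (assumed v2 API) with four corner correspondences:

import torch
from torchvision.transforms.v2 import functional as F

image = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)
startpoints = [[0, 0], [63, 0], [63, 63], [0, 63]]  # original corners (x, y)
endpoints = [[5, 3], [60, 0], [63, 58], [0, 60]]    # where those corners should land
out = F.perspective(image, startpoints, endpoints, fill=0)
print(out.shape)  # spatial size is preserved: torch.Size([3, 64, 64])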
@@ -1759,12 +1722,7 @@ def elastic_image(
     interpolation = _check_interpolation(interpolation)

-    if image.numel() == 0:
-        return image
-
-    shape = image.shape
-    ndim = image.ndim
+    height, width = image.shape[-2:]

     device = image.device
     dtype = image.dtype if torch.is_floating_point(image) else torch.float32
@@ -1775,32 +1733,18 @@ def elastic_image(
         dtype = torch.float32

     # We are aware that if input image dtype is uint8 and displacement is float64 then
-    # displacement will be casted to float32 and all computations will be done with float32
+    # displacement will be cast to float32 and all computations will be done with float32
     # We can fix this later if needed

-    expected_shape = (1,) + shape[-2:] + (2,)
+    expected_shape = (1, height, width, 2)
     if expected_shape != displacement.shape:
         raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")

-    if ndim > 4:
-        image = image.reshape((-1,) + shape[-3:])
-        needs_unsquash = True
-    elif ndim == 3:
-        image = image.unsqueeze(0)
-        needs_unsquash = True
-    else:
-        needs_unsquash = False
-
-    if displacement.dtype != dtype or displacement.device != device:
-        displacement = displacement.to(dtype=dtype, device=device)
-
-    image_height, image_width = shape[-2:]
-    grid = _create_identity_grid((image_height, image_width), device=device, dtype=dtype).add_(displacement)
+    grid = _create_identity_grid((height, width), device=device, dtype=dtype).add_(
+        displacement.to(dtype=dtype, device=device)
+    )
     output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

-    if needs_unsquash:
-        output = output.reshape(shape)
-
     if is_cpu_half:
         output = output.to(torch.float16)
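elastic_image keeps the (1, height, width, 2) displacement check but now folds the dtype/device conversion into the grid construction and relies on _apply_grid_transform for batching and empty inputs. A usage sketch (assumed v2 API) with a small random displacement field:

import torch
from torchvision.transforms.v2 import functional as F

image = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
displacement = torch.randn(1, 32, 32, 2) * 0.05  # (1, H, W, 2), offsets in normalized grid units
warped = F.elastic(image, displacement, fill=0)
print(warped.shape)  # torch.Size([3, 32, 32])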