Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
ee28bb3c
"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "4d2fa1908e531c6c815c026533b0e51a10ef9aef"
Unverified
Commit
ee28bb3c
authored
Oct 02, 2023
by
Philip Meier
Committed by
GitHub
Oct 02, 2023
Browse files
cleanup affine grid image kernels (#8004)
parent
f96deba0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
39 additions
and
94 deletions
+39
-94
test/test_transforms_v2_refactored.py
test/test_transforms_v2_refactored.py
+2
-1
torchvision/transforms/v2/functional/_geometry.py
torchvision/transforms/v2/functional/_geometry.py
+37
-93
No files found.
test/test_transforms_v2_refactored.py
View file @
ee28bb3c
...
@@ -2491,7 +2491,7 @@ class TestElastic:
...
@@ -2491,7 +2491,7 @@ class TestElastic:
interpolation
=
[
transforms
.
InterpolationMode
.
NEAREST
,
transforms
.
InterpolationMode
.
BILINEAR
],
interpolation
=
[
transforms
.
InterpolationMode
.
NEAREST
,
transforms
.
InterpolationMode
.
BILINEAR
],
fill
=
EXHAUSTIVE_TYPE_FILLS
,
fill
=
EXHAUSTIVE_TYPE_FILLS
,
)
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
uint8
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
uint8
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
def
test_kernel_image
(
self
,
param
,
value
,
dtype
,
device
):
def
test_kernel_image
(
self
,
param
,
value
,
dtype
,
device
):
image
=
make_image_tensor
(
dtype
=
dtype
,
device
=
device
)
image
=
make_image_tensor
(
dtype
=
dtype
,
device
=
device
)
...
@@ -2502,6 +2502,7 @@ class TestElastic:
...
@@ -2502,6 +2502,7 @@ class TestElastic:
displacement
=
self
.
_make_displacement
(
image
),
displacement
=
self
.
_make_displacement
(
image
),
**
{
param
:
value
},
**
{
param
:
value
},
check_scripted_vs_eager
=
not
(
param
==
"fill"
and
isinstance
(
value
,
(
int
,
float
))),
check_scripted_vs_eager
=
not
(
param
==
"fill"
and
isinstance
(
value
,
(
int
,
float
))),
check_cuda_vs_cpu
=
dtype
is
not
torch
.
float16
,
)
)
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
tv_tensors
.
BoundingBoxFormat
))
@
pytest
.
mark
.
parametrize
(
"format"
,
list
(
tv_tensors
.
BoundingBoxFormat
))
...
...
torchvision/transforms/v2/functional/_geometry.py
View file @
ee28bb3c
...
@@ -551,19 +551,30 @@ def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[in
...
@@ -551,19 +551,30 @@ def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[in
def
_apply_grid_transform
(
img
:
torch
.
Tensor
,
grid
:
torch
.
Tensor
,
mode
:
str
,
fill
:
_FillTypeJIT
)
->
torch
.
Tensor
:
def
_apply_grid_transform
(
img
:
torch
.
Tensor
,
grid
:
torch
.
Tensor
,
mode
:
str
,
fill
:
_FillTypeJIT
)
->
torch
.
Tensor
:
input_shape
=
img
.
shape
output_height
,
output_width
=
grid
.
shape
[
1
],
grid
.
shape
[
2
]
num_channels
,
input_height
,
input_width
=
input_shape
[
-
3
:]
output_shape
=
input_shape
[:
-
3
]
+
(
num_channels
,
output_height
,
output_width
)
if
img
.
numel
()
==
0
:
return
img
.
reshape
(
output_shape
)
img
=
img
.
reshape
(
-
1
,
num_channels
,
input_height
,
input_width
)
squashed_batch_size
=
img
.
shape
[
0
]
# We are using context knowledge that grid should have float dtype
# We are using context knowledge that grid should have float dtype
fp
=
img
.
dtype
==
grid
.
dtype
fp
=
img
.
dtype
==
grid
.
dtype
float_img
=
img
if
fp
else
img
.
to
(
grid
.
dtype
)
float_img
=
img
if
fp
else
img
.
to
(
grid
.
dtype
)
shape
=
float_img
.
shape
if
squashed_batch_size
>
1
:
if
shape
[
0
]
>
1
:
# Apply same grid to a batch of images
# Apply same grid to a batch of images
grid
=
grid
.
expand
(
s
hape
[
0
]
,
-
1
,
-
1
,
-
1
)
grid
=
grid
.
expand
(
s
quashed_batch_size
,
-
1
,
-
1
,
-
1
)
# Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
# Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
if
fill
is
not
None
:
if
fill
is
not
None
:
mask
=
torch
.
ones
((
shape
[
0
],
1
,
shape
[
2
],
shape
[
3
]),
dtype
=
float_img
.
dtype
,
device
=
float_img
.
device
)
mask
=
torch
.
ones
(
(
squashed_batch_size
,
1
,
input_height
,
input_width
),
dtype
=
float_img
.
dtype
,
device
=
float_img
.
device
)
float_img
=
torch
.
cat
((
float_img
,
mask
),
dim
=
1
)
float_img
=
torch
.
cat
((
float_img
,
mask
),
dim
=
1
)
float_img
=
grid_sample
(
float_img
,
grid
,
mode
=
mode
,
padding_mode
=
"zeros"
,
align_corners
=
False
)
float_img
=
grid_sample
(
float_img
,
grid
,
mode
=
mode
,
padding_mode
=
"zeros"
,
align_corners
=
False
)
...
@@ -584,7 +595,7 @@ def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill
...
@@ -584,7 +595,7 @@ def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill
img
=
float_img
.
round_
().
to
(
img
.
dtype
)
if
not
fp
else
float_img
img
=
float_img
.
round_
().
to
(
img
.
dtype
)
if
not
fp
else
float_img
return
img
return
img
.
reshape
(
output_shape
)
def
_assert_grid_transform_inputs
(
def
_assert_grid_transform_inputs
(
...
@@ -661,24 +672,10 @@ def affine_image(
...
@@ -661,24 +672,10 @@ def affine_image(
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
interpolation
=
_check_interpolation
(
interpolation
)
interpolation
=
_check_interpolation
(
interpolation
)
if
image
.
numel
()
==
0
:
return
image
shape
=
image
.
shape
ndim
=
image
.
ndim
if
ndim
>
4
:
image
=
image
.
reshape
((
-
1
,)
+
shape
[
-
3
:])
needs_unsquash
=
True
elif
ndim
==
3
:
image
=
image
.
unsqueeze
(
0
)
needs_unsquash
=
True
else
:
needs_unsquash
=
False
height
,
width
=
shape
[
-
2
:]
angle
,
translate
,
shear
,
center
=
_affine_parse_args
(
angle
,
translate
,
scale
,
shear
,
interpolation
,
center
)
angle
,
translate
,
shear
,
center
=
_affine_parse_args
(
angle
,
translate
,
scale
,
shear
,
interpolation
,
center
)
height
,
width
=
image
.
shape
[
-
2
:]
center_f
=
[
0.0
,
0.0
]
center_f
=
[
0.0
,
0.0
]
if
center
is
not
None
:
if
center
is
not
None
:
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
...
@@ -692,12 +689,7 @@ def affine_image(
...
@@ -692,12 +689,7 @@ def affine_image(
dtype
=
image
.
dtype
if
torch
.
is_floating_point
(
image
)
else
torch
.
float32
dtype
=
image
.
dtype
if
torch
.
is_floating_point
(
image
)
else
torch
.
float32
theta
=
torch
.
tensor
(
matrix
,
dtype
=
dtype
,
device
=
image
.
device
).
reshape
(
1
,
2
,
3
)
theta
=
torch
.
tensor
(
matrix
,
dtype
=
dtype
,
device
=
image
.
device
).
reshape
(
1
,
2
,
3
)
grid
=
_affine_grid
(
theta
,
w
=
width
,
h
=
height
,
ow
=
width
,
oh
=
height
)
grid
=
_affine_grid
(
theta
,
w
=
width
,
h
=
height
,
ow
=
width
,
oh
=
height
)
output
=
_apply_grid_transform
(
image
,
grid
,
interpolation
.
value
,
fill
=
fill
)
return
_apply_grid_transform
(
image
,
grid
,
interpolation
.
value
,
fill
=
fill
)
if
needs_unsquash
:
output
=
output
.
reshape
(
shape
)
return
output
@
_register_kernel_internal
(
affine
,
PIL
.
Image
.
Image
)
@
_register_kernel_internal
(
affine
,
PIL
.
Image
.
Image
)
...
@@ -969,35 +961,26 @@ def rotate_image(
...
@@ -969,35 +961,26 @@ def rotate_image(
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
interpolation
=
_check_interpolation
(
interpolation
)
interpolation
=
_check_interpolation
(
interpolation
)
shape
=
image
.
shape
input_height
,
input_width
=
image
.
shape
[
-
2
:]
num_channels
,
height
,
width
=
shape
[
-
3
:]
center_f
=
[
0.0
,
0.0
]
center_f
=
[
0.0
,
0.0
]
if
center
is
not
None
:
if
center
is
not
None
:
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
center_f
=
[(
c
-
s
*
0.5
)
for
c
,
s
in
zip
(
center
,
[
width
,
height
])]
center_f
=
[(
c
-
s
*
0.5
)
for
c
,
s
in
zip
(
center
,
[
input_
width
,
input_
height
])]
# due to current incoherence of rotation angle direction between affine and rotate implementations
# due to current incoherence of rotation angle direction between affine and rotate implementations
# we need to set -angle.
# we need to set -angle.
matrix
=
_get_inverse_affine_matrix
(
center_f
,
-
angle
,
[
0.0
,
0.0
],
1.0
,
[
0.0
,
0.0
])
matrix
=
_get_inverse_affine_matrix
(
center_f
,
-
angle
,
[
0.0
,
0.0
],
1.0
,
[
0.0
,
0.0
])
if
image
.
numel
()
>
0
:
_assert_grid_transform_inputs
(
image
,
matrix
,
interpolation
.
value
,
fill
,
[
"nearest"
,
"bilinear"
])
image
=
image
.
reshape
(
-
1
,
num_channels
,
height
,
width
)
_assert_grid_transform_inputs
(
image
,
matrix
,
interpolation
.
value
,
fill
,
[
"nearest"
,
"bilinear"
])
ow
,
oh
=
_compute_affine_output_size
(
matrix
,
width
,
height
)
if
expand
else
(
width
,
height
)
dtype
=
image
.
dtype
if
torch
.
is_floating_point
(
image
)
else
torch
.
float32
theta
=
torch
.
tensor
(
matrix
,
dtype
=
dtype
,
device
=
image
.
device
).
reshape
(
1
,
2
,
3
)
grid
=
_affine_grid
(
theta
,
w
=
width
,
h
=
height
,
ow
=
ow
,
oh
=
oh
)
output
=
_apply_grid_transform
(
image
,
grid
,
interpolation
.
value
,
fill
=
fill
)
new_height
,
new_width
=
output
.
shape
[
-
2
:]
else
:
output
=
image
new_width
,
new_height
=
_compute_affine_output_size
(
matrix
,
width
,
height
)
if
expand
else
(
width
,
height
)
return
output
.
reshape
(
shape
[:
-
3
]
+
(
num_channels
,
new_height
,
new_width
))
output_width
,
output_height
=
(
_compute_affine_output_size
(
matrix
,
input_width
,
input_height
)
if
expand
else
(
input_width
,
input_height
)
)
dtype
=
image
.
dtype
if
torch
.
is_floating_point
(
image
)
else
torch
.
float32
theta
=
torch
.
tensor
(
matrix
,
dtype
=
dtype
,
device
=
image
.
device
).
reshape
(
1
,
2
,
3
)
grid
=
_affine_grid
(
theta
,
w
=
input_width
,
h
=
input_height
,
ow
=
output_width
,
oh
=
output_height
)
return
_apply_grid_transform
(
image
,
grid
,
interpolation
.
value
,
fill
=
fill
)
@
_register_kernel_internal
(
rotate
,
PIL
.
Image
.
Image
)
@
_register_kernel_internal
(
rotate
,
PIL
.
Image
.
Image
)
...
@@ -1509,21 +1492,6 @@ def perspective_image(
...
@@ -1509,21 +1492,6 @@ def perspective_image(
perspective_coeffs
=
_perspective_coefficients
(
startpoints
,
endpoints
,
coefficients
)
perspective_coeffs
=
_perspective_coefficients
(
startpoints
,
endpoints
,
coefficients
)
interpolation
=
_check_interpolation
(
interpolation
)
interpolation
=
_check_interpolation
(
interpolation
)
if
image
.
numel
()
==
0
:
return
image
shape
=
image
.
shape
ndim
=
image
.
ndim
if
ndim
>
4
:
image
=
image
.
reshape
((
-
1
,)
+
shape
[
-
3
:])
needs_unsquash
=
True
elif
ndim
==
3
:
image
=
image
.
unsqueeze
(
0
)
needs_unsquash
=
True
else
:
needs_unsquash
=
False
_assert_grid_transform_inputs
(
_assert_grid_transform_inputs
(
image
,
image
,
matrix
=
None
,
matrix
=
None
,
...
@@ -1533,15 +1501,10 @@ def perspective_image(
...
@@ -1533,15 +1501,10 @@ def perspective_image(
coeffs
=
perspective_coeffs
,
coeffs
=
perspective_coeffs
,
)
)
oh
,
ow
=
shape
[
-
2
:]
oh
,
ow
=
image
.
shape
[
-
2
:]
dtype
=
image
.
dtype
if
torch
.
is_floating_point
(
image
)
else
torch
.
float32
dtype
=
image
.
dtype
if
torch
.
is_floating_point
(
image
)
else
torch
.
float32
grid
=
_perspective_grid
(
perspective_coeffs
,
ow
=
ow
,
oh
=
oh
,
dtype
=
dtype
,
device
=
image
.
device
)
grid
=
_perspective_grid
(
perspective_coeffs
,
ow
=
ow
,
oh
=
oh
,
dtype
=
dtype
,
device
=
image
.
device
)
output
=
_apply_grid_transform
(
image
,
grid
,
interpolation
.
value
,
fill
=
fill
)
return
_apply_grid_transform
(
image
,
grid
,
interpolation
.
value
,
fill
=
fill
)
if
needs_unsquash
:
output
=
output
.
reshape
(
shape
)
return
output
@
_register_kernel_internal
(
perspective
,
PIL
.
Image
.
Image
)
@
_register_kernel_internal
(
perspective
,
PIL
.
Image
.
Image
)
...
@@ -1759,12 +1722,7 @@ def elastic_image(
...
@@ -1759,12 +1722,7 @@ def elastic_image(
interpolation
=
_check_interpolation
(
interpolation
)
interpolation
=
_check_interpolation
(
interpolation
)
if
image
.
numel
()
==
0
:
height
,
width
=
image
.
shape
[
-
2
:]
return
image
shape
=
image
.
shape
ndim
=
image
.
ndim
device
=
image
.
device
device
=
image
.
device
dtype
=
image
.
dtype
if
torch
.
is_floating_point
(
image
)
else
torch
.
float32
dtype
=
image
.
dtype
if
torch
.
is_floating_point
(
image
)
else
torch
.
float32
...
@@ -1775,32 +1733,18 @@ def elastic_image(
...
@@ -1775,32 +1733,18 @@ def elastic_image(
dtype
=
torch
.
float32
dtype
=
torch
.
float32
# We are aware that if input image dtype is uint8 and displacement is float64 then
# We are aware that if input image dtype is uint8 and displacement is float64 then
# displacement will be cast
ed
to float32 and all computations will be done with float32
# displacement will be cast to float32 and all computations will be done with float32
# We can fix this later if needed
# We can fix this later if needed
expected_shape
=
(
1
,
)
+
shape
[
-
2
:]
+
(
2
,
)
expected_shape
=
(
1
,
height
,
width
,
2
)
if
expected_shape
!=
displacement
.
shape
:
if
expected_shape
!=
displacement
.
shape
:
raise
ValueError
(
f
"Argument displacement shape should be
{
expected_shape
}
, but given
{
displacement
.
shape
}
"
)
raise
ValueError
(
f
"Argument displacement shape should be
{
expected_shape
}
, but given
{
displacement
.
shape
}
"
)
if
ndim
>
4
:
grid
=
_create_identity_grid
((
height
,
width
),
device
=
device
,
dtype
=
dtype
).
add_
(
image
=
image
.
reshape
((
-
1
,)
+
shape
[
-
3
:])
displacement
.
to
(
dtype
=
dtype
,
device
=
device
)
needs_unsquash
=
True
)
elif
ndim
==
3
:
image
=
image
.
unsqueeze
(
0
)
needs_unsquash
=
True
else
:
needs_unsquash
=
False
if
displacement
.
dtype
!=
dtype
or
displacement
.
device
!=
device
:
displacement
=
displacement
.
to
(
dtype
=
dtype
,
device
=
device
)
image_height
,
image_width
=
shape
[
-
2
:]
grid
=
_create_identity_grid
((
image_height
,
image_width
),
device
=
device
,
dtype
=
dtype
).
add_
(
displacement
)
output
=
_apply_grid_transform
(
image
,
grid
,
interpolation
.
value
,
fill
=
fill
)
output
=
_apply_grid_transform
(
image
,
grid
,
interpolation
.
value
,
fill
=
fill
)
if
needs_unsquash
:
output
=
output
.
reshape
(
shape
)
if
is_cpu_half
:
if
is_cpu_half
:
output
=
output
.
to
(
torch
.
float16
)
output
=
output
.
to
(
torch
.
float16
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment