Commit af7c6c04 (Unverified), OpenDAS / vision

Fixed issues with dtype in geom functional transforms v2 (#7211)

Authored Feb 13, 2023 by vfdev, committed by GitHub on Feb 13, 2023
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Parent: ea37cd38

Showing 7 changed files with 101 additions and 56 deletions (+101 -56):
  test/prototype_common_utils.py                              +3   -3
  test/prototype_transforms_kernel_infos.py                   +21  -2
  test/test_prototype_transforms_consistency.py               +22  -11
  test/test_prototype_transforms_functional.py                +1   -1
  torchvision/prototype/transforms/_misc.py                   +8   -1
  torchvision/prototype/transforms/functional/_geometry.py    +37  -35
  torchvision/transforms/transforms.py                        +9   -3
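For context before the per-file diffs: the commit threads the input's floating dtype (notably float64) through the v2 geometry kernels instead of assuming float32. A minimal sketch of the underlying constraint, not code from this commit: grid_sample requires the image and the sampling grid to share a floating-point dtype, so a float64 image paired with a grid built in the default float32 either fails outright or forces a lossy cast.

    # Illustrative sketch only: why grid dtype must follow the image dtype.
    import torch
    import torch.nn.functional as F

    img = torch.rand(1, 3, 8, 8, dtype=torch.float64)

    # Building theta with the image's dtype makes affine_grid produce a
    # float64 grid; a float32 grid here would make grid_sample raise.
    theta = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype=img.dtype).unsqueeze(0)
    grid = F.affine_grid(theta, size=img.shape, align_corners=False)

    out = F.grid_sample(img, grid, mode="bilinear", align_corners=False)
    assert out.dtype == torch.float64  # double precision preserved end to end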
test/prototype_common_utils.py
@@ -304,7 +304,7 @@ def make_image_loaders(
         "RGBA",
     ),
     extra_dims=DEFAULT_EXTRA_DIMS,
-    dtypes=(torch.float32, torch.uint8),
+    dtypes=(torch.float32, torch.float64, torch.uint8),
     constant_alpha=True,
 ):
     for params in combinations_grid(size=sizes, color_space=color_spaces, extra_dims=extra_dims, dtype=dtypes):
@@ -426,7 +426,7 @@ def make_bounding_box_loaders(
     extra_dims=DEFAULT_EXTRA_DIMS,
     formats=tuple(datapoints.BoundingBoxFormat),
     spatial_size="random",
-    dtypes=(torch.float32, torch.int64),
+    dtypes=(torch.float32, torch.float64, torch.int64),
 ):
     for params in combinations_grid(extra_dims=extra_dims, format=formats, dtype=dtypes):
         yield make_bounding_box_loader(**params, spatial_size=spatial_size)
@@ -618,7 +618,7 @@ def make_video_loaders(
     ),
     num_frames=(1, 0, "random"),
     extra_dims=DEFAULT_EXTRA_DIMS,
-    dtypes=(torch.uint8,),
+    dtypes=(torch.uint8, torch.float32, torch.float64),
 ):
     for params in combinations_grid(
         size=sizes, color_space=color_spaces, num_frames=num_frames, extra_dims=extra_dims, dtype=dtypes
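Adding torch.float64 to these tuples multiplies the test matrix: assuming combinations_grid has the usual Cartesian-product semantics of these test utilities, every loader combination is now also exercised in double precision. A hedged sketch of that expansion:

    # Illustrative only: combinations_grid as a Cartesian product over kwargs.
    import itertools

    def combinations_grid(**kwargs):
        return [dict(zip(kwargs, values)) for values in itertools.product(*kwargs.values())]

    grid = combinations_grid(dtype=("float32", "float64", "uint8"), extra_dims=((), (4,)))
    print(len(grid))  # 6 parameter dicts, two of them float64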
test/prototype_transforms_kernel_infos.py
@@ -109,6 +109,12 @@ def float32_vs_uint8_pixel_difference(atol=1, mae=False):
     }


+def scripted_vs_eager_double_pixel_difference(device, atol=1e-6, rtol=1e-6):
+    return {
+        (("TestKernels", "test_scripted_vs_eager"), torch.float64, device): {"atol": atol, "rtol": rtol, "mae": False},
+    }
+
+
 def pil_reference_wrapper(pil_kernel):
     @functools.wraps(pil_kernel)
     def wrapper(input_tensor, *other_args, **kwargs):
@@ -541,8 +547,10 @@ def reference_affine_bounding_box_helper(bounding_box, *, format, affine_matrix)
     def transform(bbox, affine_matrix_, format_):
         # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
         in_dtype = bbox.dtype
+        if not torch.is_floating_point(bbox):
+            bbox = bbox.float()
         bbox_xyxy = F.convert_format_bounding_box(
-            bbox.float(), old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
+            bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
         )
         points = np.array(
             [
@@ -560,6 +568,7 @@ def reference_affine_bounding_box_helper(bounding_box, *, format, affine_matrix)
                 np.max(transformed_points[:, 0]).item(),
                 np.max(transformed_points[:, 1]).item(),
             ],
+            dtype=bbox_xyxy.dtype,
         )
         out_bbox = F.convert_format_bounding_box(
             out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
@@ -844,6 +853,10 @@ KERNEL_INFOS.extend(
         KernelInfo(
             F.rotate_bounding_box,
             sample_inputs_fn=sample_inputs_rotate_bounding_box,
+            closeness_kwargs={
+                **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-6, rtol=1e-6),
+                **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
+            },
         ),
         KernelInfo(
             F.rotate_mask,
@@ -1275,6 +1288,8 @@ KERNEL_INFOS.extend(
             **pil_reference_pixel_difference(2, mae=True),
             **cuda_vs_cpu_pixel_difference(),
             **float32_vs_uint8_pixel_difference(),
+            **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5),
+            **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
         },
     ),
     KernelInfo(
@@ -1294,7 +1309,11 @@ KERNEL_INFOS.extend(
         KernelInfo(
             F.perspective_video,
             sample_inputs_fn=sample_inputs_perspective_video,
-            closeness_kwargs=cuda_vs_cpu_pixel_difference(),
+            closeness_kwargs={
+                **cuda_vs_cpu_pixel_difference(),
+                **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5),
+                **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
+            },
         ),
     ]
 )
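These closeness dictionaries are keyed on (test id, dtype, device), so the relaxed float64 tolerances apply only to the scripted-vs-eager comparison. A hedged sketch of how such a table can be resolved (the resolver below is illustrative, not the test suite's actual get_closeness_kwargs):

    # Illustrative only: resolving per-(test, dtype, device) tolerances.
    import torch

    closeness_kwargs = {
        (("TestKernels", "test_scripted_vs_eager"), torch.float64, "cpu"): {"atol": 1e-6, "rtol": 1e-6},
        (("TestKernels", "test_scripted_vs_eager"), torch.float64, "cuda"): {"atol": 1e-5, "rtol": 1e-5},
    }

    def get_closeness_kwargs(test_id, dtype, device):
        # Fall back to exact comparison when no relaxed tolerance is registered.
        return closeness_kwargs.get((test_id, dtype, device), {"atol": 0, "rtol": 0})

    print(get_closeness_kwargs(("TestKernels", "test_scripted_vs_eager"), torch.float64, "cpu"))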
test/test_prototype_transforms_consistency.py
@@ -138,17 +138,28 @@ CONSISTENCY_CONFIGS = [
             NotScriptableArgsKwargs(5, padding_mode="symmetric"),
         ],
     ),
-    ConsistencyConfig(
-        prototype_transforms.LinearTransformation,
-        legacy_transforms.LinearTransformation,
-        [
-            ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX, LINEAR_TRANSFORMATION_MEAN),
-        ],
-        # Make sure that the product of the height, width and number of channels matches the number of elements in
-        # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
-        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"]),
-        supports_pil=False,
-    ),
+    *[
+        ConsistencyConfig(
+            prototype_transforms.LinearTransformation,
+            legacy_transforms.LinearTransformation,
+            [
+                ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX.to(matrix_dtype), LINEAR_TRANSFORMATION_MEAN.to(matrix_dtype)),
+            ],
+            # Make sure that the product of the height, width and number of channels matches the number of elements in
+            # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
+            make_images_kwargs=dict(
+                DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"], dtypes=[image_dtype]
+            ),
+            supports_pil=False,
+        )
+        for matrix_dtype, image_dtype in [
+            (torch.float32, torch.float32),
+            (torch.float64, torch.float64),
+            (torch.float32, torch.uint8),
+            (torch.float64, torch.float32),
+            (torch.float32, torch.float64),
+        ]
+    ],
     ConsistencyConfig(
         prototype_transforms.Grayscale,
         legacy_transforms.Grayscale,
test/test_prototype_transforms_functional.py
@@ -142,7 +142,7 @@ class TestKernels:
             actual,
             expected,
             **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
-            msg=parametrized_error_message(*other_args, **kwargs),
+            msg=parametrized_error_message(*([actual, expected] + other_args), **kwargs),
         )

     def _unbatch(self, batch, *, data_dims):
torchvision/prototype/transforms/_misc.py
@@ -64,6 +64,11 @@ class LinearTransformation(Transform):
                 f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
             )

+        if transformation_matrix.dtype != mean_vector.dtype:
+            raise ValueError(
+                f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}"
+            )
+
         self.transformation_matrix = transformation_matrix
         self.mean_vector = mean_vector
@@ -93,7 +98,9 @@ class LinearTransformation(Transform):
         )

         flat_tensor = inpt.reshape(-1, n) - self.mean_vector
-        transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
+
+        transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype)
+        transformed_tensor = torch.mm(flat_tensor, transformation_matrix)
         return transformed_tensor.reshape(shape)
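The cast matters because torch.mm does not promote mixed dtypes: multiplying a float64 input by a float32 matrix raises a RuntimeError. A standalone sketch of the pattern, separate from the transform itself:

    # Standalone sketch: cast the matrix once to the input's dtype before mm.
    import torch

    matrix = torch.eye(12, dtype=torch.float32)    # e.g. a whitening matrix
    flat = torch.rand(5, 12, dtype=torch.float64)  # double-precision input

    # torch.mm(flat, matrix) would raise: both tensors must share a dtype.
    out = torch.mm(flat, matrix.to(flat.dtype))
    assert out.dtype == torch.float64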
torchvision/prototype/transforms/functional/_geometry.py
@@ -404,9 +404,13 @@ def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[in
 def _apply_grid_transform(
-    float_img: torch.Tensor, grid: torch.Tensor, mode: str, fill: datapoints.FillTypeJIT
+    img: torch.Tensor, grid: torch.Tensor, mode: str, fill: datapoints.FillTypeJIT
 ) -> torch.Tensor:

+    # We are using context knowledge that grid should have float dtype
+    fp = img.dtype == grid.dtype
+    float_img = img if fp else img.to(grid.dtype)
+
     shape = float_img.shape
     if shape[0] > 1:
         # Apply same grid to a batch of images
@@ -433,7 +437,9 @@ def _apply_grid_transform(
         # img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill
         float_img = float_img.sub_(fill_img).mul_(mask).add_(fill_img)

-    return float_img
+    img = float_img.round_().to(img.dtype) if not fp else float_img
+
+    return img


 def _assert_grid_transform_inputs(
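This refactor moves the uint8 -> float -> round -> uint8 round trip out of every caller and into _apply_grid_transform itself. A self-contained sketch of that pattern (simplified; the real kernel also reshapes batches and applies fill):

    # Simplified sketch of the centralized dtype round trip: upcast integer
    # images to the grid's floating dtype, sample, then round back.
    import torch
    import torch.nn.functional as F

    def apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str) -> torch.Tensor:
        fp = img.dtype == grid.dtype                      # grid is always floating point
        float_img = img if fp else img.to(grid.dtype)     # upcast uint8 (etc.) once
        out = F.grid_sample(float_img, grid, mode=mode, align_corners=False)
        return out if fp else out.round_().to(img.dtype)  # round back for integer inputs

    img = torch.randint(0, 256, (1, 3, 4, 4), dtype=torch.uint8)
    theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])  # identity affine, float32
    grid = F.affine_grid(theta, size=list(img.shape), align_corners=False)
    assert apply_grid_transform(img, grid, "bilinear").dtype == torch.uint8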
@@ -511,7 +517,6 @@ def affine_image_tensor(
     shape = image.shape
     ndim = image.ndim
-    fp = torch.is_floating_point(image)

     if ndim > 4:
         image = image.reshape((-1,) + shape[-3:])
@@ -535,13 +540,10 @@ def affine_image_tensor(
     _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])

-    dtype = image.dtype if fp else torch.float32
+    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
     theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
     grid = _affine_grid(theta, w=width, h=height, ow=width, oh=height)
-    output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill)
-
-    if not fp:
-        output = output.round_().to(image.dtype)
+    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

     if needs_unsquash:
         output = output.reshape(shape)
@@ -612,7 +614,7 @@ def _affine_bounding_box_xyxy(
     # Single point structure is similar to
     # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
     points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
-    points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1)
+    points = torch.cat([points, torch.ones(points.shape[0], 1, device=device, dtype=dtype)], dim=-1)
     # 2) Now let's transform the points using affine matrix
     transformed_points = torch.matmul(points, transposed_affine_matrix)
     # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
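The homogeneous-coordinate column needs an explicit dtype because factory functions like torch.ones fall back to the global default (float32) otherwise, which would drag float64 box math down to single precision. A brief illustration:

    # Illustration: torch.ones uses the default dtype unless told otherwise,
    # so the "1" column must inherit the boxes' dtype explicitly.
    import torch

    points = torch.rand(8, 2, dtype=torch.float64)
    ones_default = torch.ones(points.shape[0], 1)                      # float32
    ones_matched = torch.ones(points.shape[0], 1, dtype=points.dtype)  # float64

    homogeneous = torch.cat([points, ones_matched], dim=-1)
    assert homogeneous.dtype == torch.float64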
@@ -797,19 +799,15 @@ def rotate_image_tensor(
     matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])

     if image.numel() > 0:
-        fp = torch.is_floating_point(image)
         image = image.reshape(-1, num_channels, height, width)

         _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])

         ow, oh = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
-        dtype = image.dtype if fp else torch.float32
+        dtype = image.dtype if torch.is_floating_point(image) else torch.float32
         theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
         grid = _affine_grid(theta, w=width, h=height, ow=ow, oh=oh)
-        output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill)
-
-        if not fp:
-            output = output.round_().to(image.dtype)
+        output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

         new_height, new_width = output.shape[-2:]
     else:
@@ -1237,9 +1235,9 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,
     d = 0.5
     base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
-    x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device)
+    x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device, dtype=dtype)
     base_grid[..., 0].copy_(x_grid)
-    y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
+    y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device, dtype=dtype).unsqueeze_(-1)
     base_grid[..., 1].copy_(y_grid)
     base_grid[..., 2].fill_(1)
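Without the explicit dtype, linspace produces float32 values that are then copied into a float64 base_grid, so the extra precision is lost before any sampling happens. A small sketch of that leak:

    # Sketch of the precision leak fixed here: copy_ into a float64 buffer
    # cannot restore digits that float32 linspace never computed.
    import torch

    lo = torch.linspace(0.0, 1.0, steps=1001)                       # float32 (default dtype)
    hi = torch.linspace(0.0, 1.0, steps=1001, dtype=torch.float64)  # float64

    buf = torch.empty(1001, dtype=torch.float64)
    buf.copy_(lo)  # upcasts, but the float32 rounding error is already baked in
    print((buf - hi).abs().max())  # roughly 1e-8: a float32-sized error survives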
@@ -1283,7 +1281,6 @@ def perspective_image_tensor(
     shape = image.shape
     ndim = image.ndim
-    fp = torch.is_floating_point(image)

     if ndim > 4:
         image = image.reshape((-1,) + shape[-3:])
@@ -1304,12 +1301,9 @@ def perspective_image_tensor(
     )

     oh, ow = shape[-2:]
-    dtype = image.dtype if fp else torch.float32
+    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
     grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=image.device)
-    output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill)
-
-    if not fp:
-        output = output.round_().to(image.dtype)
+    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

     if needs_unsquash:
         output = output.reshape(shape)
@@ -1494,8 +1488,12 @@ def elastic_image_tensor(
     shape = image.shape
     ndim = image.ndim

     device = image.device
-    fp = torch.is_floating_point(image)
+    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
+
+    # We are aware that if input image dtype is uint8 and displacement is float64 then
+    # displacement will be casted to float32 and all computations will be done with float32
+    # We can fix this later if needed

     if ndim > 4:
         image = image.reshape((-1,) + shape[-3:])
@@ -1506,12 +1504,12 @@ def elastic_image_tensor(
     else:
         needs_unsquash = False

-    image_height, image_width = shape[-2:]
-    grid = _create_identity_grid((image_height, image_width), device=device).add_(displacement.to(device))
-    output = _apply_grid_transform(image if fp else image.to(torch.float32), grid, interpolation.value, fill=fill)
-
-    if not fp:
-        output = output.round_().to(image.dtype)
+    if displacement.dtype != dtype or displacement.device != device:
+        displacement = displacement.to(dtype=dtype, device=device)
+
+    image_height, image_width = shape[-2:]
+    grid = _create_identity_grid((image_height, image_width), device=device, dtype=dtype).add_(displacement)
+    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

     if needs_unsquash:
         output = output.reshape(shape)
@@ -1531,13 +1529,13 @@ def elastic_image_pil(
     return to_pil_image(output, mode=image.mode)


-def _create_identity_grid(size: Tuple[int, int], device: torch.device) -> torch.Tensor:
+def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: torch.dtype) -> torch.Tensor:
     sy, sx = size
-    base_grid = torch.empty(1, sy, sx, 2, device=device)
+    base_grid = torch.empty(1, sy, sx, 2, device=device, dtype=dtype)
-    x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device)
+    x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device, dtype=dtype)
     base_grid[..., 0].copy_(x_grid)
-    y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device).unsqueeze_(-1)
+    y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device, dtype=dtype).unsqueeze_(-1)
     base_grid[..., 1].copy_(y_grid)

     return base_grid
@@ -1552,7 +1550,11 @@ def elastic_bounding_box(
         return bounding_box

     # TODO: add in docstring about approximation we are doing for grid inversion
-    displacement = displacement.to(bounding_box.device)
+    device = bounding_box.device
+    dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
+
+    if displacement.dtype != dtype or displacement.device != device:
+        displacement = displacement.to(dtype=dtype, device=device)

     original_shape = bounding_box.shape
     bounding_box = (
@@ -1563,7 +1565,7 @@ def elastic_bounding_box(
     # Or add spatial_size arg and check displacement shape
     spatial_size = displacement.shape[-3], displacement.shape[-2]

-    id_grid = _create_identity_grid(spatial_size, bounding_box.device)
+    id_grid = _create_identity_grid(spatial_size, device=device, dtype=dtype)
     # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
     # This is not an exact inverse of the grid
     inv_grid = id_grid.sub_(displacement)
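On the approximation the comments mention: if the forward sampling grid is g(x) = x + d(x), the code inverts it with g~(y) = y - d(y), which is exact when d is constant and a good approximation when it varies slowly, since the composition error is on the order of d times its derivative. A hedged 1-D numeric sketch:

    # Numeric sketch (illustrative values) of the inverse-grid approximation:
    # g~(g(x)) = x + d(x) - d(x + d(x)) is close to x for smooth, small d.
    import torch

    x = torch.linspace(-1, 1, 101, dtype=torch.float64)
    amplitude = 0.05
    d = lambda t: amplitude * torch.sin(3.14159 * t)  # smooth, small displacement

    forward = x + d(x)
    approx_inverted = forward - d(forward)            # id_grid - displacement idea
    print((approx_inverted - x).abs().max())          # small, but not exactly 0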
torchvision/transforms/transforms.py
@@ -1078,6 +1078,11 @@ class LinearTransformation(torch.nn.Module):
                 f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
             )

+        if transformation_matrix.dtype != mean_vector.dtype:
+            raise ValueError(
+                f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}"
+            )
+
         self.transformation_matrix = transformation_matrix
         self.mean_vector = mean_vector
@@ -1105,9 +1110,10 @@ class LinearTransformation(torch.nn.Module):
         )

         flat_tensor = tensor.view(-1, n) - self.mean_vector
-        transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
-        tensor = transformed_tensor.view(shape)
-        return tensor
+
+        transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype)
+        transformed_tensor = torch.mm(flat_tensor, transformation_matrix)
+        return transformed_tensor.view(shape)

     def __repr__(self) -> str:
         s = (
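A hedged usage sketch of the net effect in the stable API (arbitrary values): a float32 whitening matrix can now be applied to a float64 input, because the matrix is cast to the input's dtype in forward(), and a matrix/mean pair with mismatched dtypes is rejected in __init__ instead of failing later inside torch.mm.

    # Usage sketch only; sizes chosen so channels * height * width == dim.
    import torch
    from torchvision import transforms

    dim = 3 * 4 * 4
    t = transforms.LinearTransformation(
        transformation_matrix=torch.eye(dim, dtype=torch.float32),
        mean_vector=torch.zeros(dim, dtype=torch.float32),
    )

    img = torch.rand(3, 4, 4, dtype=torch.float64)
    out = t(img)  # previously raised inside torch.mm on the dtype mismatch
    assert out.dtype == torch.float64

    # Mismatched matrix/mean dtypes now fail fast with a ValueError:
    # transforms.LinearTransformation(torch.eye(dim, dtype=torch.float64), torch.zeros(dim))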