Commit 99ec261c in OpenDAS/vision (unverified), authored May 16, 2023 by vfdev and committed via GitHub on May 16, 2023.
Resize V2 relies on interpolate's native uint8 handling (#7557)
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
Parent: fc838add
Showing 5 changed files with 111 additions and 22 deletions:
- test/common_utils.py (+27, -10)
- test/test_transforms_v2_consistency.py (+10, -3)
- test/test_transforms_v2_functional.py (+30, -0)
- test/transforms_v2_kernel_infos.py (+11, -7)
- torchvision/transforms/v2/functional/_geometry.py (+33, -2)
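In short, `resize_image_tensor` now lets `torch.nn.functional.interpolate` operate on uint8 images directly (nearest everywhere, bilinear on AVX2-capable CPUs) instead of always round-tripping through float32. A minimal sketch of the two paths, not taken from the commit and assuming a PyTorch build with native uint8 bilinear support:

```python
# Sketch (not from the commit): old vs. new uint8 resize behaviour at the torch level.
import torch
from torch.nn.functional import interpolate

img = torch.randint(0, 256, (1, 3, 64, 64), dtype=torch.uint8)

# Old v2 path: always round-trip through float32.
old = interpolate(img.float(), size=(32, 32), mode="bilinear", antialias=True)
old = old.round_().clamp_(0, 255).to(torch.uint8)

# New path: interpolate handles uint8 natively (bilinear needs an AVX2-capable CPU).
new = interpolate(img, size=(32, 32), mode="bilinear", antialias=True)
assert new.dtype == torch.uint8

# The two paths can differ by one intensity level, hence atol=1 in the tests below.
assert (new.int() - old.int()).abs().max() <= 1
```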
test/common_utils.py

```diff
@@ -465,11 +465,15 @@ class TensorLoader:
 class ImageLoader(TensorLoader):
     spatial_size: Tuple[int, int] = dataclasses.field(init=False)
     num_channels: int = dataclasses.field(init=False)
+    memory_format: torch.memory_format = torch.contiguous_format

     def __post_init__(self):
         self.spatial_size = self.shape[-2:]
         self.num_channels = self.shape[-3]

+    def load(self, device):
+        return self.fn(self.shape, self.dtype, device, memory_format=self.memory_format)


 NUM_CHANNELS_MAP = {
     "GRAY": 1,
```
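The new `memory_format` field is threaded through every loader so the kernels get exercised on both layouts. For reference, a small illustration (not from the commit) of how the two formats share shape and values and differ only in strides:

```python
# Sketch: contiguous vs. channels_last layouts differ only in strides.
import torch

x = torch.zeros(1, 3, 4, 5)                  # NCHW, contiguous: strides (60, 20, 5, 1)
y = x.to(memory_format=torch.channels_last)  # same shape, strides (60, 1, 15, 3)
print(x.stride(), y.stride())
print(torch.equal(x, y))  # True: values are identical, only the layout changes
```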
```diff
@@ -493,18 +497,21 @@ def make_image_loader(
     extra_dims=(),
     dtype=torch.float32,
     constant_alpha=True,
+    memory_format=torch.contiguous_format,
 ):
     size = _parse_spatial_size(size)
     num_channels = get_num_channels(color_space)

-    def fn(shape, dtype, device):
+    def fn(shape, dtype, device, memory_format):
         max_value = get_max_value(dtype)
-        data = torch.testing.make_tensor(shape, low=0, high=max_value, dtype=dtype, device=device)
+        data = torch.testing.make_tensor(
+            shape, low=0, high=max_value, dtype=dtype, device=device, memory_format=memory_format
+        )
         if color_space in {"GRAY_ALPHA", "RGBA"} and constant_alpha:
             data[..., -1, :, :] = max_value
         return datapoints.Image(data)

-    return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype)
+    return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype, memory_format=memory_format)


 make_image = from_loader(make_image_loader)
```
```diff
@@ -530,11 +537,13 @@ def make_image_loaders(
 make_images = from_loaders(make_image_loaders)


-def make_image_loader_for_interpolation(size="random", *, color_space="RGB", dtype=torch.uint8):
+def make_image_loader_for_interpolation(
+    size="random", *, color_space="RGB", dtype=torch.uint8, memory_format=torch.contiguous_format
+):
     size = _parse_spatial_size(size)
     num_channels = get_num_channels(color_space)

-    def fn(shape, dtype, device):
+    def fn(shape, dtype, device, memory_format):
         height, width = shape[-2:]

         image_pil = (
@@ -550,19 +559,25 @@ def make_image_loader_for_interpolation(size="random", *, color_space="RGB", dty
             )
         )

-        image_tensor = convert_dtype_image_tensor(to_image_tensor(image_pil).to(device=device), dtype=dtype)
+        image_tensor = to_image_tensor(image_pil)
+        if memory_format == torch.contiguous_format:
+            image_tensor = image_tensor.to(device=device, memory_format=memory_format, copy=True)
+        else:
+            image_tensor = image_tensor.to(device=device)
+        image_tensor = convert_dtype_image_tensor(image_tensor, dtype=dtype)

         return datapoints.Image(image_tensor)

-    return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype)
+    return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype, memory_format=memory_format)


 def make_image_loaders_for_interpolation(
     sizes=((233, 147),),
     color_spaces=("RGB",),
     dtypes=(torch.uint8,),
+    memory_formats=(torch.contiguous_format, torch.channels_last),
 ):
-    for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes):
+    for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes, memory_format=memory_formats):
         yield make_image_loader_for_interpolation(**params)
```
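`combinations_grid` is a torchvision test helper that expands keyword lists into the cartesian product of parameter dicts, so each loader is now built once per memory format. Roughly, and assuming dict-cartesian-product semantics:

```python
# Rough sketch of combinations_grid (assumed semantics, not the commit's code).
import itertools
import torch

def combinations_grid(**kwargs):
    return [dict(zip(kwargs, values)) for values in itertools.product(*kwargs.values())]

grid = combinations_grid(
    dtype=[torch.uint8],
    memory_format=[torch.contiguous_format, torch.channels_last],
)
# -> [{'dtype': torch.uint8, 'memory_format': torch.contiguous_format},
#     {'dtype': torch.uint8, 'memory_format': torch.channels_last}]
```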
```diff
@@ -744,8 +759,10 @@ def make_video_loader(
     size = _parse_spatial_size(size)
     num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames

-    def fn(shape, dtype, device):
-        video = make_image(size=shape[-2:], extra_dims=shape[:-3], dtype=dtype, device=device)
+    def fn(shape, dtype, device, memory_format):
+        video = make_image(
+            size=shape[-2:], extra_dims=shape[:-3], dtype=dtype, device=device, memory_format=memory_format
+        )
         return datapoints.Video(video)

     return VideoLoader(fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype)
```
test/test_transforms_v2_consistency.py

```diff
@@ -98,6 +98,8 @@ CONSISTENCY_CONFIGS = [
             ArgsKwargs((29, 32), antialias=False),
             ArgsKwargs((28, 31), antialias=True),
         ],
+        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
+        closeness_kwargs=dict(rtol=0, atol=1),
     ),
     ConsistencyConfig(
         v2_transforms.CenterCrop,
@@ -313,6 +315,8 @@ CONSISTENCY_CONFIGS = [
             ArgsKwargs((29, 32), antialias=False),
             ArgsKwargs((28, 31), antialias=True),
         ],
+        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
+        closeness_kwargs=dict(rtol=0, atol=1),
     ),
     ConsistencyConfig(
         v2_transforms.RandomErasing,
@@ -783,7 +787,8 @@ class TestContainerTransforms:
             ]
         )

-        check_call_consistency(prototype_transform, legacy_transform)
+        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
+        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))

     @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
     @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
@@ -807,7 +812,8 @@ class TestContainerTransforms:
             p=p,
         )

-        check_call_consistency(prototype_transform, legacy_transform)
+        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
+        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))

         if sequence_type is nn.ModuleList:
             # quick and dirty test that it is jit-scriptable
@@ -832,7 +838,8 @@ class TestContainerTransforms:
             p=probabilities,
         )

-        check_call_consistency(prototype_transform, legacy_transform)
+        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
+        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))


 class TestToTensorTransforms:
```
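The `closeness_kwargs=dict(rtol=0, atol=1)` above tolerates an off-by-one intensity difference between the native uint8 path and the legacy float path. Assuming the consistency checks go through `torch.testing.assert_close`, this is what the tolerance permits:

```python
# Sketch: rtol=0, atol=1 accepts results that differ by at most one intensity level.
import torch

expected = torch.tensor([10, 200], dtype=torch.uint8)
actual = torch.tensor([11, 199], dtype=torch.uint8)
torch.testing.assert_close(actual, expected, rtol=0, atol=1)  # passes
```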
test/test_transforms_v2_functional.py

```diff
@@ -1365,3 +1365,33 @@ def test_correctness_uniform_temporal_subsample(device):
     out_video = F.uniform_temporal_subsample(video, 8)

     assert out_video.unique().tolist() == [0, 1, 2, 3, 5, 6, 7, 9]
+
+
+# TODO: We can remove this test and related torchvision workaround
+# once we fixed related pytorch issue: https://github.com/pytorch/pytorch/issues/68430
+@make_info_args_kwargs_parametrization(
+    [info for info in KERNEL_INFOS if info.kernel is F.resize_image_tensor],
+    args_kwargs_fn=lambda info: info.reference_inputs_fn(),
+)
+def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwargs):
+    (input, *other_args), kwargs = args_kwargs.load("cpu")
+
+    output = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)
+
+    error_msg_fn = parametrized_error_message(input, *other_args, **kwargs)
+    assert input.ndim == 3, error_msg_fn
+    input_stride = input.stride()
+    output_stride = output.stride()
+    # Here we check output memory format according to the input:
+    # if input_stride is (..., 1) then input is most likely channels first and thus
+    #   output strides should match channels first strides (H * W, W, 1)
+    # if input_stride is (1, ...) then input is most likely channels last and thus
+    #   output strides should match channels last strides (1, W * C, C)
+    if input_stride[-1] == 1:
+        expected_stride = (output.shape[-2] * output.shape[-1], output.shape[-1], 1)
+        assert expected_stride == output_stride, error_msg_fn("")
+    elif input_stride[0] == 1:
+        expected_stride = (1, output.shape[0] * output.shape[-1], output.shape[0])
+        assert expected_stride == output_stride, error_msg_fn("")
+    else:
+        assert False, error_msg_fn("")
```
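The test infers the memory format from strides because a plain `torch.Tensor` carries no explicit layout flag. A small illustration (not from the commit) of the two stride patterns it checks for a C x H x W image:

```python
# Sketch: stride patterns for a C x H x W image in the two layouts.
import torch

c, h, w = 3, 4, 5
channels_first = torch.empty(c, h, w)
print(channels_first.stride())  # (h*w, w, 1) == (20, 5, 1): last stride is 1

channels_last = torch.empty(1, c, h, w).to(memory_format=torch.channels_last)[0]
print(channels_last.stride())   # (1, w*c, c) == (1, 15, 3): first stride is 1
```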
test/transforms_v2_kernel_infos.py

```diff
@@ -1569,7 +1569,7 @@ def reference_inputs_equalize_image_tensor():
     # We are not using `make_image_loaders` here since that uniformly samples the values over the whole value range.
     # Since the whole point of this kernel is to transform an arbitrary distribution of values into a uniform one,
     # the information gain is low if we already provide something really close to the expected value.
-    def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor):
+    def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor, memory_format):
         if dtype.is_floating_point:
             low = low_factor
             high = high_factor
@@ -1577,23 +1577,27 @@ def reference_inputs_equalize_image_tensor():
             max_value = torch.iinfo(dtype).max
             low = int(low_factor * max_value)
             high = int(high_factor * max_value)
-        return torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high)
+        return torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high).to(
+            memory_format=memory_format, copy=True
+        )

-    def make_beta_distributed_image(shape, dtype, device, *, alpha, beta):
+    def make_beta_distributed_image(shape, dtype, device, *, alpha, beta, memory_format):
         image = torch.distributions.Beta(alpha, beta).sample(shape)
         if not dtype.is_floating_point:
             image.mul_(torch.iinfo(dtype).max).round_()
-        return image.to(dtype=dtype, device=device)
+        return image.to(dtype=dtype, device=device, memory_format=memory_format, copy=True)

     spatial_size = (256, 256)
     for dtype, color_space, fn in itertools.product(
         [torch.uint8],
         ["GRAY", "RGB"],
         [
-            lambda shape, dtype, device: torch.zeros(shape, dtype=dtype, device=device),
-            lambda shape, dtype, device: torch.full(
+            lambda shape, dtype, device, memory_format: torch.zeros(shape, dtype=dtype, device=device).to(
+                memory_format=memory_format, copy=True
+            ),
+            lambda shape, dtype, device, memory_format: torch.full(
                 shape, 1.0 if dtype.is_floating_point else torch.iinfo(dtype).max, dtype=dtype, device=device
-            ),
+            ).to(memory_format=memory_format, copy=True),
             *[
                 functools.partial(make_uniform_band_image, low_factor=low_factor, high_factor=high_factor)
                 for low_factor, high_factor in [
```
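Note the recurring `.to(memory_format=..., copy=True)` pattern: without `copy=True`, `.to()` may return the input tensor unchanged when the requested layout is already satisfied, so the loaders force a fresh tensor in the requested format. A quick illustration (not from the commit):

```python
# Sketch: .to(memory_format=...) can be a no-op; copy=True guarantees a new tensor.
import torch

x = torch.empty(1, 3, 4, 5)  # already contiguous
maybe_same = x.to(memory_format=torch.contiguous_format)
always_new = x.to(memory_format=torch.contiguous_format, copy=True)
print(maybe_same.data_ptr() == x.data_ptr())  # True: no copy was made
print(always_new.data_ptr() == x.data_ptr())  # False: a fresh tensor
```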
torchvision/transforms/v2/functional/_geometry.py

```diff
@@ -176,16 +176,47 @@ def resize_image_tensor(
         antialias = False

     shape = image.shape
+    numel = image.numel()
     num_channels, old_height, old_width = shape[-3:]
     new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)

     if (new_height, new_width) == (old_height, old_width):
         return image
-    elif image.numel() > 0:
+    elif numel > 0:
         image = image.reshape(-1, num_channels, old_height, old_width)

         dtype = image.dtype
-        need_cast = dtype not in (torch.float32, torch.float64)
+        acceptable_dtypes = [torch.float32, torch.float64]
+        if interpolation == InterpolationMode.NEAREST or interpolation == InterpolationMode.NEAREST_EXACT:
+            # uint8 dtype can be included for cpu and cuda input if nearest mode
+            acceptable_dtypes.append(torch.uint8)
+        elif interpolation == InterpolationMode.BILINEAR and image.device.type == "cpu":
+            # uint8 dtype support for bilinear mode is limited to cpu and
+            # according to our benchmarks non-AVX CPUs should prefer u8->f32->interpolate->u8 path
+            if "AVX2" in torch.backends.cpu.get_cpu_capability():
+                acceptable_dtypes.append(torch.uint8)
+
+        # TODO: Remove when https://github.com/pytorch/pytorch/pull/101136 is landed
+        if dtype == torch.uint8 and not (
+            image.is_contiguous() or image.is_contiguous(memory_format=torch.channels_last)
+        ):
+            image = image.contiguous(memory_format=torch.channels_last)
+
+        strides = image.stride()
+        if image.is_contiguous(memory_format=torch.channels_last) and image.shape[0] == 1 and numel != strides[0]:
+            # There is a weird behaviour in torch core where the output tensor of `interpolate()` can be allocated as
+            # contiguous even though the input is un-ambiguously channels_last (https://github.com/pytorch/pytorch/issues/68430).
+            # In particular this happens for the typical torchvision use-case of single CHW images where we fake the batch dim
+            # to become 1CHW. Below, we restride those tensors to trick torch core into properly allocating the output as
+            # channels_last, thus preserving the memory format of the input. This is not just for format consistency:
+            # for uint8 bilinear images, this also avoids an extra copy (re-packing) of the output and saves time.
+            # TODO: when https://github.com/pytorch/pytorch/issues/68430 is fixed (possibly by https://github.com/pytorch/pytorch/pull/100373),
+            # we should be able to remove this hack.
+            new_strides = list(strides)
+            new_strides[0] = numel
+            image = image.as_strided((1, num_channels, old_height, old_width), new_strides)
+
+        need_cast = dtype not in acceptable_dtypes
         if need_cast:
             image = image.to(dtype=torch.float32)
```
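The restriding block above exists because a fake-batched 1CHW tensor has a size-1 batch dimension whose stride is arbitrary, which can make torch core mis-detect the layout and allocate a contiguous output (pytorch/pytorch#68430). A sketch of the situation, assuming PyTorch's usual `unsqueeze` stride convention:

```python
# Sketch: why the batch stride of a fake-batched 1CHW image can be "wrong".
import torch

chw = torch.empty(32, 32, 3, dtype=torch.uint8).permute(2, 0, 1)  # CHW view, strides (1, 96, 3)
img = chw.unsqueeze(0)                                            # 1CHW, strides (3, 1, 96, 3)
print(img.stride(), img.numel())                                  # batch stride 3 != numel 3072

# The workaround above: set the batch stride to numel() so the layout is
# unambiguously channels_last and interpolate() allocates its output accordingly.
strides = list(img.stride())
strides[0] = img.numel()
img = img.as_strided(img.shape, strides)
print(img.stride())                                               # (3072, 1, 96, 3)
```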