Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
c206a471
Unverified
Commit
c206a471
authored
Jan 23, 2023
by
Philip Meier
Committed by
GitHub
Jan 23, 2023
Browse files
add reference test for normalize_image_tensor (#7119)
parent
d2d448c7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
40 additions
and
1 deletion
+40
-1
test/prototype_transforms_kernel_infos.py
test/prototype_transforms_kernel_infos.py
+18
-0
test/test_prototype_transforms_functional.py
test/test_prototype_transforms_functional.py
+22
-1
No files found.
test/prototype_transforms_kernel_infos.py
View file @
c206a471
...
...
@@ -2232,6 +2232,22 @@ def sample_inputs_normalize_image_tensor():
yield
ArgsKwargs
(
image_loader
,
mean
=
mean
,
std
=
std
)
def
reference_normalize_image_tensor
(
image
,
mean
,
std
,
inplace
=
False
):
mean
=
torch
.
tensor
(
mean
).
view
(
-
1
,
1
,
1
)
std
=
torch
.
tensor
(
std
).
view
(
-
1
,
1
,
1
)
sub
=
torch
.
Tensor
.
sub_
if
inplace
else
torch
.
Tensor
.
sub
return
sub
(
image
,
mean
).
div_
(
std
)
def
reference_inputs_normalize_image_tensor
():
yield
ArgsKwargs
(
make_image_loader
(
size
=
(
32
,
32
),
color_space
=
datapoints
.
ColorSpace
.
RGB
,
extra_dims
=
[
1
]),
mean
=
[
0.5
,
0.5
,
0.5
],
std
=
[
1.0
,
1.0
,
1.0
],
)
def
sample_inputs_normalize_video
():
mean
,
std
=
_NORMALIZE_MEANS_STDS
[
0
]
for
video_loader
in
make_video_loaders
(
...
...
@@ -2246,6 +2262,8 @@ KERNEL_INFOS.extend(
F
.
normalize_image_tensor
,
kernel_name
=
"normalize_image_tensor"
,
sample_inputs_fn
=
sample_inputs_normalize_image_tensor
,
reference_fn
=
reference_normalize_image_tensor
,
reference_inputs_fn
=
reference_inputs_normalize_image_tensor
,
test_marks
=
[
xfail_jit_python_scalar_arg
(
"mean"
),
xfail_jit_python_scalar_arg
(
"std"
),
...
...
test/test_prototype_transforms_functional.py
View file @
c206a471
...
...
@@ -13,7 +13,12 @@ import torch
import
torchvision.prototype.transforms.utils
from
common_utils
import
cache
,
cpu_and_gpu
,
needs_cuda
,
set_rng_seed
from
prototype_common_utils
import
assert_close
,
make_bounding_boxes
,
parametrized_error_message
from
prototype_common_utils
import
(
assert_close
,
DEFAULT_SQUARE_SPATIAL_SIZE
,
make_bounding_boxes
,
parametrized_error_message
,
)
from
prototype_transforms_dispatcher_infos
import
DISPATCHER_INFOS
from
prototype_transforms_kernel_infos
import
KERNEL_INFOS
from
torch.utils._pytree
import
tree_map
...
...
@@ -538,6 +543,22 @@ def test_convert_dtype_image_tensor_dtype_and_device(info, args_kwargs, device):
assert
output
.
device
==
input
.
device
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
@
pytest
.
mark
.
parametrize
(
"num_channels"
,
[
1
,
3
])
def
test_normalize_image_tensor_stats
(
device
,
num_channels
):
stats
=
pytest
.
importorskip
(
"scipy.stats"
,
reason
=
"SciPy is not available"
)
def
assert_samples_from_standard_normal
(
t
):
p_value
=
stats
.
kstest
(
t
.
flatten
(),
cdf
=
"norm"
,
args
=
(
0
,
1
)).
pvalue
return
p_value
>
1e-4
image
=
torch
.
rand
(
num_channels
,
DEFAULT_SQUARE_SPATIAL_SIZE
,
DEFAULT_SQUARE_SPATIAL_SIZE
)
mean
=
image
.
mean
(
dim
=
(
1
,
2
)).
tolist
()
std
=
image
.
std
(
dim
=
(
1
,
2
)).
tolist
()
assert_samples_from_standard_normal
(
F
.
normalize_image_tensor
(
image
,
mean
,
std
))
# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
# `prototype_transforms_kernel_infos.py`
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment