Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
c206a471
"cpu/arange_interleave.cpp" did not exist on "4a569c27736957e6606fd3b3f69712a808189a74"
Unverified
Commit
c206a471
authored
Jan 23, 2023
by
Philip Meier
Committed by
GitHub
Jan 23, 2023
Browse files
add reference test for normalize_image_tensor (#7119)
parent
d2d448c7
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
40 additions
and
1 deletion
+40
-1
test/prototype_transforms_kernel_infos.py
test/prototype_transforms_kernel_infos.py
+18
-0
test/test_prototype_transforms_functional.py
test/test_prototype_transforms_functional.py
+22
-1
No files found.
test/prototype_transforms_kernel_infos.py
View file @
c206a471
...
...
@@ -2232,6 +2232,22 @@ def sample_inputs_normalize_image_tensor():
yield
ArgsKwargs
(
image_loader
,
mean
=
mean
,
std
=
std
)
def
reference_normalize_image_tensor
(
image
,
mean
,
std
,
inplace
=
False
):
mean
=
torch
.
tensor
(
mean
).
view
(
-
1
,
1
,
1
)
std
=
torch
.
tensor
(
std
).
view
(
-
1
,
1
,
1
)
sub
=
torch
.
Tensor
.
sub_
if
inplace
else
torch
.
Tensor
.
sub
return
sub
(
image
,
mean
).
div_
(
std
)
def
reference_inputs_normalize_image_tensor
():
yield
ArgsKwargs
(
make_image_loader
(
size
=
(
32
,
32
),
color_space
=
datapoints
.
ColorSpace
.
RGB
,
extra_dims
=
[
1
]),
mean
=
[
0.5
,
0.5
,
0.5
],
std
=
[
1.0
,
1.0
,
1.0
],
)
def
sample_inputs_normalize_video
():
mean
,
std
=
_NORMALIZE_MEANS_STDS
[
0
]
for
video_loader
in
make_video_loaders
(
...
...
@@ -2246,6 +2262,8 @@ KERNEL_INFOS.extend(
F
.
normalize_image_tensor
,
kernel_name
=
"normalize_image_tensor"
,
sample_inputs_fn
=
sample_inputs_normalize_image_tensor
,
reference_fn
=
reference_normalize_image_tensor
,
reference_inputs_fn
=
reference_inputs_normalize_image_tensor
,
test_marks
=
[
xfail_jit_python_scalar_arg
(
"mean"
),
xfail_jit_python_scalar_arg
(
"std"
),
...
...
test/test_prototype_transforms_functional.py
View file @
c206a471
...
...
@@ -13,7 +13,12 @@ import torch
import
torchvision.prototype.transforms.utils
from
common_utils
import
cache
,
cpu_and_gpu
,
needs_cuda
,
set_rng_seed
from
prototype_common_utils
import
assert_close
,
make_bounding_boxes
,
parametrized_error_message
from
prototype_common_utils
import
(
assert_close
,
DEFAULT_SQUARE_SPATIAL_SIZE
,
make_bounding_boxes
,
parametrized_error_message
,
)
from
prototype_transforms_dispatcher_infos
import
DISPATCHER_INFOS
from
prototype_transforms_kernel_infos
import
KERNEL_INFOS
from
torch.utils._pytree
import
tree_map
...
...
@@ -538,6 +543,22 @@ def test_convert_dtype_image_tensor_dtype_and_device(info, args_kwargs, device):
assert
output
.
device
==
input
.
device
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
@
pytest
.
mark
.
parametrize
(
"num_channels"
,
[
1
,
3
])
def
test_normalize_image_tensor_stats
(
device
,
num_channels
):
stats
=
pytest
.
importorskip
(
"scipy.stats"
,
reason
=
"SciPy is not available"
)
def
assert_samples_from_standard_normal
(
t
):
p_value
=
stats
.
kstest
(
t
.
flatten
(),
cdf
=
"norm"
,
args
=
(
0
,
1
)).
pvalue
return
p_value
>
1e-4
image
=
torch
.
rand
(
num_channels
,
DEFAULT_SQUARE_SPATIAL_SIZE
,
DEFAULT_SQUARE_SPATIAL_SIZE
)
mean
=
image
.
mean
(
dim
=
(
1
,
2
)).
tolist
()
std
=
image
.
std
(
dim
=
(
1
,
2
)).
tolist
()
assert_samples_from_standard_normal
(
F
.
normalize_image_tensor
(
image
,
mean
,
std
))
# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
# `prototype_transforms_kernel_infos.py`
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment