Unverified Commit c206a471 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add reference test for normalize_image_tensor (#7119)

parent d2d448c7
......@@ -2232,6 +2232,22 @@ def sample_inputs_normalize_image_tensor():
yield ArgsKwargs(image_loader, mean=mean, std=std)
def reference_normalize_image_tensor(image, mean, std, inplace=False):
mean = torch.tensor(mean).view(-1, 1, 1)
std = torch.tensor(std).view(-1, 1, 1)
sub = torch.Tensor.sub_ if inplace else torch.Tensor.sub
return sub(image, mean).div_(std)
def reference_inputs_normalize_image_tensor():
yield ArgsKwargs(
make_image_loader(size=(32, 32), color_space=datapoints.ColorSpace.RGB, extra_dims=[1]),
mean=[0.5, 0.5, 0.5],
std=[1.0, 1.0, 1.0],
)
def sample_inputs_normalize_video():
mean, std = _NORMALIZE_MEANS_STDS[0]
for video_loader in make_video_loaders(
......@@ -2246,6 +2262,8 @@ KERNEL_INFOS.extend(
F.normalize_image_tensor,
kernel_name="normalize_image_tensor",
sample_inputs_fn=sample_inputs_normalize_image_tensor,
reference_fn=reference_normalize_image_tensor,
reference_inputs_fn=reference_inputs_normalize_image_tensor,
test_marks=[
xfail_jit_python_scalar_arg("mean"),
xfail_jit_python_scalar_arg("std"),
......
......@@ -13,7 +13,12 @@ import torch
import torchvision.prototype.transforms.utils
from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed
from prototype_common_utils import assert_close, make_bounding_boxes, parametrized_error_message
from prototype_common_utils import (
assert_close,
DEFAULT_SQUARE_SPATIAL_SIZE,
make_bounding_boxes,
parametrized_error_message,
)
from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS
from prototype_transforms_kernel_infos import KERNEL_INFOS
from torch.utils._pytree import tree_map
......@@ -538,6 +543,22 @@ def test_convert_dtype_image_tensor_dtype_and_device(info, args_kwargs, device):
assert output.device == input.device
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("num_channels", [1, 3])
def test_normalize_image_tensor_stats(device, num_channels):
stats = pytest.importorskip("scipy.stats", reason="SciPy is not available")
def assert_samples_from_standard_normal(t):
p_value = stats.kstest(t.flatten(), cdf="norm", args=(0, 1)).pvalue
return p_value > 1e-4
image = torch.rand(num_channels, DEFAULT_SQUARE_SPATIAL_SIZE, DEFAULT_SQUARE_SPATIAL_SIZE)
mean = image.mean(dim=(1, 2)).tolist()
std = image.std(dim=(1, 2)).tolist()
assert_samples_from_standard_normal(F.normalize_image_tensor(image, mean, std))
# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
# `prototype_transforms_kernel_infos.py`
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment