Unverified Commit 6e72f2fd authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Add seeds on Kernel Info and reduce randomness for Gaussian Blur (#6741)

* Add seeds on Kernel Info and reduce randomness for Gaussian Blur

* Fix linter
parent 4d4711d9
......@@ -49,12 +49,14 @@ class KernelInfo(InfoBase):
test_marks=None,
# See InfoBase
closeness_kwargs=None,
seed=None,
):
super().__init__(id=kernel_name or kernel.__name__, test_marks=test_marks, closeness_kwargs=closeness_kwargs)
self.kernel = kernel
self.sample_inputs_fn = sample_inputs_fn
self.reference_fn = reference_fn
self.reference_inputs_fn = reference_inputs_fn
self.seed = seed
DEFAULT_IMAGE_CLOSENESS_KWARGS = dict(
......@@ -1304,7 +1306,7 @@ KERNEL_INFOS.extend(
def sample_inputs_gaussian_blur_image_tensor():
make_gaussian_blur_image_loaders = functools.partial(
make_image_loaders, sizes=["random"], color_spaces=[features.ColorSpace.RGB]
make_image_loaders, sizes=[(7, 33)], color_spaces=[features.ColorSpace.RGB]
)
for image_loader, kernel_size in itertools.product(make_gaussian_blur_image_loaders(), [5, (3, 3), [3, 3]]):
......@@ -1317,7 +1319,7 @@ def sample_inputs_gaussian_blur_image_tensor():
def sample_inputs_gaussian_blur_video():
for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
for video_loader in make_video_loaders(sizes=[(7, 33)], num_frames=[5]):
yield ArgsKwargs(video_loader, kernel_size=[3, 3])
......@@ -1331,10 +1333,13 @@ KERNEL_INFOS.extend(
xfail_jit_python_scalar_arg("kernel_size"),
xfail_jit_python_scalar_arg("sigma"),
],
seed=0,
),
KernelInfo(
F.gaussian_blur_video,
sample_inputs_fn=sample_inputs_gaussian_blur_video,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
seed=0,
),
]
)
......
......@@ -6,7 +6,7 @@ import PIL.Image
import pytest
import torch
from common_utils import cache, cpu_and_gpu, needs_cuda
from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed
from prototype_common_utils import assert_close, make_bounding_boxes, make_image
from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS
from prototype_transforms_kernel_infos import KERNEL_INFOS
......@@ -81,6 +81,8 @@ class TestKernels:
@sample_inputs
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_scripted_vs_eager(self, info, args_kwargs, device):
if info.seed is not None:
set_rng_seed(info.seed)
kernel_eager = info.kernel
kernel_scripted = script(kernel_eager)
......@@ -111,6 +113,8 @@ class TestKernels:
@sample_inputs
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_batched_vs_single(self, info, args_kwargs, device):
if info.seed is not None:
set_rng_seed(info.seed)
(batched_input, *other_args), kwargs = args_kwargs.load(device)
feature_type = features.Image if features.is_simple_tensor(batched_input) else type(batched_input)
......@@ -146,6 +150,8 @@ class TestKernels:
@sample_inputs
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_no_inplace(self, info, args_kwargs, device):
if info.seed is not None:
set_rng_seed(info.seed)
(input, *other_args), kwargs = args_kwargs.load(device)
if input.numel() == 0:
......@@ -159,6 +165,8 @@ class TestKernels:
@sample_inputs
@needs_cuda
def test_cuda_vs_cpu(self, info, args_kwargs):
if info.seed is not None:
set_rng_seed(info.seed)
(input_cpu, *other_args), kwargs = args_kwargs.load("cpu")
input_cuda = input_cpu.to("cuda")
......@@ -170,6 +178,8 @@ class TestKernels:
@sample_inputs
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_dtype_and_device_consistency(self, info, args_kwargs, device):
if info.seed is not None:
set_rng_seed(info.seed)
(input, *other_args), kwargs = args_kwargs.load(device)
output = info.kernel(input, *other_args, **kwargs)
......@@ -182,6 +192,8 @@ class TestKernels:
@reference_inputs
def test_against_reference(self, info, args_kwargs):
if info.seed is not None:
set_rng_seed(info.seed)
args, kwargs = args_kwargs.load("cpu")
actual = info.kernel(*args, **kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment