"vscode:/vscode.git/clone" did not exist on "80b22ad881b6be61e49179940599614f47724553"
Unverified Commit 6e72f2fd authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Add seeds on Kernel Info and reduce randomness for Gaussian Blur (#6741)

* Add seeds on Kernel Info and reduce randomness for Gaussian Blur

* Fix linter
parent 4d4711d9
......@@ -49,12 +49,14 @@ class KernelInfo(InfoBase):
test_marks=None,
# See InfoBase
closeness_kwargs=None,
seed=None,
):
super().__init__(id=kernel_name or kernel.__name__, test_marks=test_marks, closeness_kwargs=closeness_kwargs)
self.kernel = kernel
self.sample_inputs_fn = sample_inputs_fn
self.reference_fn = reference_fn
self.reference_inputs_fn = reference_inputs_fn
self.seed = seed
DEFAULT_IMAGE_CLOSENESS_KWARGS = dict(
......@@ -1304,7 +1306,7 @@ KERNEL_INFOS.extend(
def sample_inputs_gaussian_blur_image_tensor():
make_gaussian_blur_image_loaders = functools.partial(
make_image_loaders, sizes=["random"], color_spaces=[features.ColorSpace.RGB]
make_image_loaders, sizes=[(7, 33)], color_spaces=[features.ColorSpace.RGB]
)
for image_loader, kernel_size in itertools.product(make_gaussian_blur_image_loaders(), [5, (3, 3), [3, 3]]):
......@@ -1317,7 +1319,7 @@ def sample_inputs_gaussian_blur_image_tensor():
def sample_inputs_gaussian_blur_video():
for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
for video_loader in make_video_loaders(sizes=[(7, 33)], num_frames=[5]):
yield ArgsKwargs(video_loader, kernel_size=[3, 3])
......@@ -1331,10 +1333,13 @@ KERNEL_INFOS.extend(
xfail_jit_python_scalar_arg("kernel_size"),
xfail_jit_python_scalar_arg("sigma"),
],
seed=0,
),
KernelInfo(
F.gaussian_blur_video,
sample_inputs_fn=sample_inputs_gaussian_blur_video,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
seed=0,
),
]
)
......
......@@ -6,7 +6,7 @@ import PIL.Image
import pytest
import torch
from common_utils import cache, cpu_and_gpu, needs_cuda
from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed
from prototype_common_utils import assert_close, make_bounding_boxes, make_image
from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS
from prototype_transforms_kernel_infos import KERNEL_INFOS
......@@ -81,6 +81,8 @@ class TestKernels:
@sample_inputs
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_scripted_vs_eager(self, info, args_kwargs, device):
if info.seed is not None:
set_rng_seed(info.seed)
kernel_eager = info.kernel
kernel_scripted = script(kernel_eager)
......@@ -111,6 +113,8 @@ class TestKernels:
@sample_inputs
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_batched_vs_single(self, info, args_kwargs, device):
if info.seed is not None:
set_rng_seed(info.seed)
(batched_input, *other_args), kwargs = args_kwargs.load(device)
feature_type = features.Image if features.is_simple_tensor(batched_input) else type(batched_input)
......@@ -146,6 +150,8 @@ class TestKernels:
@sample_inputs
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_no_inplace(self, info, args_kwargs, device):
if info.seed is not None:
set_rng_seed(info.seed)
(input, *other_args), kwargs = args_kwargs.load(device)
if input.numel() == 0:
......@@ -159,6 +165,8 @@ class TestKernels:
@sample_inputs
@needs_cuda
def test_cuda_vs_cpu(self, info, args_kwargs):
if info.seed is not None:
set_rng_seed(info.seed)
(input_cpu, *other_args), kwargs = args_kwargs.load("cpu")
input_cuda = input_cpu.to("cuda")
......@@ -170,6 +178,8 @@ class TestKernels:
@sample_inputs
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_dtype_and_device_consistency(self, info, args_kwargs, device):
if info.seed is not None:
set_rng_seed(info.seed)
(input, *other_args), kwargs = args_kwargs.load(device)
output = info.kernel(input, *other_args, **kwargs)
......@@ -182,6 +192,8 @@ class TestKernels:
@reference_inputs
def test_against_reference(self, info, args_kwargs):
if info.seed is not None:
set_rng_seed(info.seed)
args, kwargs = args_kwargs.load("cpu")
actual = info.kernel(*args, **kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment