Unverified Commit b80f83db authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add support for fine-grained tolerance settings (#6921)

* add support for fine-grained tolerance settings

* fix test_cuda_vs_cpu
parent e3f7baaf
......@@ -628,21 +628,34 @@ def mark_framework_limitation(test_id, reason):
class InfoBase:
def __init__(self, *, id, test_marks=None, closeness_kwargs=None):
def __init__(
self,
*,
# Identifier if the info that shows up the parametrization.
self.id = id
id,
# Test markers that will be (conditionally) applied to an `ArgsKwargs` parametrization.
# See the `TestMark` class for details
self.test_marks = test_marks or []
# Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`.
self.closeness_kwargs = closeness_kwargs or dict()
test_marks=None,
# Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`. Keys are a 3-tuple of `test_id` (see
# `TestMark`), the dtype, and the device.
closeness_kwargs=None,
):
self.id = id
self.test_marks = test_marks or []
test_marks_map = defaultdict(list)
for test_mark in self.test_marks:
test_marks_map[test_mark.test_id].append(test_mark)
self._test_marks_map = dict(test_marks_map)
self.closeness_kwargs = closeness_kwargs or dict()
def get_marks(self, test_id, args_kwargs):
return [
test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs)
]
def get_closeness_kwargs(self, test_id, *, dtype, device):
if isinstance(device, torch.device):
device = device.type
return self.closeness_kwargs.get((test_id, dtype, device), dict())
......@@ -61,11 +61,10 @@ class KernelInfo(InfoBase):
self.reference_inputs_fn = reference_inputs_fn
DEFAULT_IMAGE_CLOSENESS_KWARGS = dict(
atol=1e-5,
rtol=0,
agg_method="mean",
)
DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS = {
(("TestKernels", "test_against_reference"), torch.float32, "cpu"): dict(atol=1e-5, rtol=0, agg_method="mean"),
(("TestKernels", "test_against_reference"), torch.uint8, "cpu"): dict(atol=1e-5, rtol=0, agg_method="mean"),
}
def pil_reference_wrapper(pil_kernel):
......@@ -176,7 +175,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_horizontal_flip_image_tensor,
reference_fn=pil_reference_wrapper(F.horizontal_flip_image_pil),
reference_inputs_fn=reference_inputs_horizontal_flip_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.horizontal_flip_bounding_box,
......@@ -320,7 +319,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_resize_image_tensor,
reference_fn=reference_resize_image_tensor,
reference_inputs_fn=reference_inputs_resize_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
test_marks=[
xfail_jit_python_scalar_arg("size"),
],
......@@ -339,7 +338,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_resize_mask,
reference_fn=reference_resize_mask,
reference_inputs_fn=reference_inputs_resize_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
test_marks=[
xfail_jit_python_scalar_arg("size"),
],
......@@ -556,7 +555,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_affine_image_tensor,
reference_fn=pil_reference_wrapper(F.affine_image_pil),
reference_inputs_fn=reference_inputs_affine_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
test_marks=[
xfail_jit_python_scalar_arg("shear"),
xfail_jit_tuple_instead_of_list("fill"),
......@@ -569,7 +568,9 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_affine_bounding_box,
reference_fn=reference_affine_bounding_box,
reference_inputs_fn=reference_inputs_affine_bounding_box,
closeness_kwargs=dict(atol=1, rtol=0),
closeness_kwargs={
(("TestKernels", "test_against_reference"), torch.int64, "cpu"): dict(atol=1, rtol=0),
},
test_marks=[
xfail_jit_python_scalar_arg("shear"),
],
......@@ -579,7 +580,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_affine_mask,
reference_fn=reference_affine_mask,
reference_inputs_fn=reference_inputs_resize_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
test_marks=[
xfail_jit_python_scalar_arg("shear"),
],
......@@ -668,7 +669,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_convert_color_space_image_tensor,
reference_fn=reference_convert_color_space_image_tensor,
reference_inputs_fn=reference_inputs_convert_color_space_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.convert_color_space_video,
......@@ -729,7 +730,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_vertical_flip_image_tensor,
reference_fn=pil_reference_wrapper(F.vertical_flip_image_pil),
reference_inputs_fn=reference_inputs_vertical_flip_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.vertical_flip_bounding_box,
......@@ -820,7 +821,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_rotate_image_tensor,
reference_fn=pil_reference_wrapper(F.rotate_image_pil),
reference_inputs_fn=reference_inputs_rotate_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
test_marks=[
xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
......@@ -836,7 +837,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_rotate_mask,
reference_fn=reference_rotate_mask,
reference_inputs_fn=reference_inputs_rotate_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.rotate_video,
......@@ -918,7 +919,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_crop_image_tensor,
reference_fn=pil_reference_wrapper(F.crop_image_pil),
reference_inputs_fn=reference_inputs_crop_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.crop_bounding_box,
......@@ -931,7 +932,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_crop_mask,
reference_fn=pil_reference_wrapper(F.crop_image_pil),
reference_inputs_fn=reference_inputs_crop_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.crop_video,
......@@ -1010,7 +1011,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_resized_crop_image_tensor,
reference_fn=reference_resized_crop_image_tensor,
reference_inputs_fn=reference_inputs_resized_crop_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.resized_crop_bounding_box,
......@@ -1021,7 +1022,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_resized_crop_mask,
reference_fn=pil_reference_wrapper(F.resized_crop_image_pil),
reference_inputs_fn=reference_inputs_resized_crop_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.resized_crop_video,
......@@ -1144,7 +1145,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_pad_image_tensor,
reference_fn=pil_reference_wrapper(F.pad_image_pil),
reference_inputs_fn=reference_inputs_pad_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
test_marks=[
xfail_jit_tuple_instead_of_list("padding"),
xfail_jit_tuple_instead_of_list("fill"),
......@@ -1166,7 +1167,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_pad_mask,
reference_fn=pil_reference_wrapper(F.pad_image_pil),
reference_inputs_fn=reference_inputs_pad_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.pad_video,
......@@ -1225,7 +1226,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_perspective_image_tensor,
reference_fn=pil_reference_wrapper(F.perspective_image_pil),
reference_inputs_fn=reference_inputs_perspective_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.perspective_bounding_box,
......@@ -1236,7 +1237,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_perspective_mask,
reference_fn=pil_reference_wrapper(F.perspective_image_pil),
reference_inputs_fn=reference_inputs_perspective_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.perspective_video,
......@@ -1306,7 +1307,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_elastic_image_tensor,
reference_fn=pil_reference_wrapper(F.elastic_image_pil),
reference_inputs_fn=reference_inputs_elastic_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.elastic_bounding_box,
......@@ -1317,7 +1318,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_elastic_mask,
reference_fn=pil_reference_wrapper(F.elastic_image_pil),
reference_inputs_fn=reference_inputs_elastic_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.elastic_video,
......@@ -1387,7 +1388,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_center_crop_image_tensor,
reference_fn=pil_reference_wrapper(F.center_crop_image_pil),
reference_inputs_fn=reference_inputs_center_crop_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
test_marks=[
xfail_jit_python_scalar_arg("output_size"),
],
......@@ -1404,7 +1405,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_center_crop_mask,
reference_fn=pil_reference_wrapper(F.center_crop_image_pil),
reference_inputs_fn=reference_inputs_center_crop_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
test_marks=[
xfail_jit_python_scalar_arg("output_size"),
],
......@@ -1441,7 +1442,7 @@ KERNEL_INFOS.extend(
KernelInfo(
F.gaussian_blur_image_tensor,
sample_inputs_fn=sample_inputs_gaussian_blur_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
test_marks=[
xfail_jit_python_scalar_arg("kernel_size"),
xfail_jit_python_scalar_arg("sigma"),
......@@ -1529,7 +1530,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_equalize_image_tensor,
reference_fn=pil_reference_wrapper(F.equalize_image_pil),
reference_inputs_fn=reference_inputs_equalize_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.equalize_video,
......@@ -1566,7 +1567,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_invert_image_tensor,
reference_fn=pil_reference_wrapper(F.invert_image_pil),
reference_inputs_fn=reference_inputs_invert_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.invert_video,
......@@ -1607,7 +1608,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_posterize_image_tensor,
reference_fn=pil_reference_wrapper(F.posterize_image_pil),
reference_inputs_fn=reference_inputs_posterize_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.posterize_video,
......@@ -1651,7 +1652,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_solarize_image_tensor,
reference_fn=pil_reference_wrapper(F.solarize_image_pil),
reference_inputs_fn=reference_inputs_solarize_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.solarize_video,
......@@ -1688,7 +1689,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_autocontrast_image_tensor,
reference_fn=pil_reference_wrapper(F.autocontrast_image_pil),
reference_inputs_fn=reference_inputs_autocontrast_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.autocontrast_video,
......@@ -1729,7 +1730,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_adjust_sharpness_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_sharpness_image_pil),
reference_inputs_fn=reference_inputs_adjust_sharpness_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.adjust_sharpness_video,
......@@ -1800,7 +1801,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_adjust_brightness_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_brightness_image_pil),
reference_inputs_fn=reference_inputs_adjust_brightness_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.adjust_brightness_video,
......@@ -1841,7 +1842,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_adjust_contrast_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_contrast_image_pil),
reference_inputs_fn=reference_inputs_adjust_contrast_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.adjust_contrast_video,
......@@ -1886,7 +1887,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_adjust_gamma_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_gamma_image_pil),
reference_inputs_fn=reference_inputs_adjust_gamma_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.adjust_gamma_video,
......@@ -1927,7 +1928,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_adjust_hue_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_hue_image_pil),
reference_inputs_fn=reference_inputs_adjust_hue_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.adjust_hue_video,
......@@ -1967,7 +1968,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_adjust_saturation_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_saturation_image_pil),
reference_inputs_fn=reference_inputs_adjust_saturation_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.adjust_saturation_video,
......@@ -2061,7 +2062,7 @@ KERNEL_INFOS.extend(
reference_fn=pil_reference_wrapper(F.five_crop_image_pil),
reference_inputs_fn=reference_inputs_five_crop_image_tensor,
test_marks=_common_five_ten_crop_marks,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.five_crop_video,
......@@ -2074,7 +2075,7 @@ KERNEL_INFOS.extend(
reference_fn=pil_reference_wrapper(F.ten_crop_image_pil),
reference_inputs_fn=reference_inputs_ten_crop_image_tensor,
test_marks=_common_five_ten_crop_marks,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
),
KernelInfo(
F.ten_crop_video,
......
......@@ -93,6 +93,13 @@ def fix_rng_seed():
yield
@pytest.fixture()
def test_id(request):
test_class_name = request.cls.__name__ if request.cls is not None else None
test_function_name = request.node.originalname
return test_class_name, test_function_name
class TestKernels:
sample_inputs = make_info_args_kwargs_parametrization(
KERNEL_INFOS,
......@@ -107,16 +114,20 @@ class TestKernels:
@ignore_jit_warning_no_profile
@sample_inputs
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_scripted_vs_eager(self, info, args_kwargs, device):
def test_scripted_vs_eager(self, test_id, info, args_kwargs, device):
kernel_eager = info.kernel
kernel_scripted = script(kernel_eager)
args, kwargs = args_kwargs.load(device)
(input, *other_args), kwargs = args_kwargs.load(device)
actual = kernel_scripted(*args, **kwargs)
expected = kernel_eager(*args, **kwargs)
actual = kernel_scripted(input, *other_args, **kwargs)
expected = kernel_eager(input, *other_args, **kwargs)
assert_close(actual, expected, **info.closeness_kwargs)
assert_close(
actual,
expected,
**info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
)
def _unbatch(self, batch, *, data_dims):
if isinstance(batch, torch.Tensor):
......@@ -137,7 +148,7 @@ class TestKernels:
@sample_inputs
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_batched_vs_single(self, info, args_kwargs, device):
def test_batched_vs_single(self, test_id, info, args_kwargs, device):
(batched_input, *other_args), kwargs = args_kwargs.load(device)
feature_type = features.Image if features.is_simple_tensor(batched_input) else type(batched_input)
......@@ -168,7 +179,11 @@ class TestKernels:
single_inputs = self._unbatch(batched_input, data_dims=data_dims)
expected = tree_map(lambda single_input: info.kernel(single_input, *other_args, **kwargs), single_inputs)
assert_close(actual, expected, **info.closeness_kwargs)
assert_close(
actual,
expected,
**info.get_closeness_kwargs(test_id, dtype=batched_input.dtype, device=batched_input.device),
)
@sample_inputs
@pytest.mark.parametrize("device", cpu_and_gpu())
......@@ -185,14 +200,19 @@ class TestKernels:
@sample_inputs
@needs_cuda
def test_cuda_vs_cpu(self, info, args_kwargs):
def test_cuda_vs_cpu(self, test_id, info, args_kwargs):
(input_cpu, *other_args), kwargs = args_kwargs.load("cpu")
input_cuda = input_cpu.to("cuda")
output_cpu = info.kernel(input_cpu, *other_args, **kwargs)
output_cuda = info.kernel(input_cuda, *other_args, **kwargs)
assert_close(output_cuda, output_cpu, check_device=False, **info.closeness_kwargs)
assert_close(
output_cuda,
output_cpu,
check_device=False,
**info.get_closeness_kwargs(test_id, dtype=input_cuda.dtype, device=input_cuda.device),
)
@sample_inputs
@pytest.mark.parametrize("device", cpu_and_gpu())
......@@ -208,13 +228,18 @@ class TestKernels:
assert output.device == input.device
@reference_inputs
def test_against_reference(self, info, args_kwargs):
args, kwargs = args_kwargs.load("cpu")
def test_against_reference(self, test_id, info, args_kwargs):
(input, *other_args), kwargs = args_kwargs.load("cpu")
actual = info.kernel(*args, **kwargs)
expected = info.reference_fn(*args, **kwargs)
actual = info.kernel(input, *other_args, **kwargs)
expected = info.reference_fn(input, *other_args, **kwargs)
assert_close(actual, expected, check_dtype=False, **info.closeness_kwargs)
assert_close(
actual,
expected,
check_dtype=False,
**info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
)
@pytest.fixture
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment