Unverified Commit b80f83db authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add support for fine-grained tolerance settings (#6921)

* add support for fine-grained tolerance settings

* fix test_cuda_vs_cpu
parent e3f7baaf
...@@ -628,21 +628,34 @@ def mark_framework_limitation(test_id, reason): ...@@ -628,21 +628,34 @@ def mark_framework_limitation(test_id, reason):
class InfoBase: class InfoBase:
def __init__(self, *, id, test_marks=None, closeness_kwargs=None): def __init__(
self,
*,
# Identifier if the info that shows up the parametrization. # Identifier if the info that shows up the parametrization.
self.id = id id,
# Test markers that will be (conditionally) applied to an `ArgsKwargs` parametrization. # Test markers that will be (conditionally) applied to an `ArgsKwargs` parametrization.
# See the `TestMark` class for details # See the `TestMark` class for details
self.test_marks = test_marks or [] test_marks=None,
# Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`. # Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`. Keys are a 3-tuple of `test_id` (see
self.closeness_kwargs = closeness_kwargs or dict() # `TestMark`), the dtype, and the device.
closeness_kwargs=None,
):
self.id = id
self.test_marks = test_marks or []
test_marks_map = defaultdict(list) test_marks_map = defaultdict(list)
for test_mark in self.test_marks: for test_mark in self.test_marks:
test_marks_map[test_mark.test_id].append(test_mark) test_marks_map[test_mark.test_id].append(test_mark)
self._test_marks_map = dict(test_marks_map) self._test_marks_map = dict(test_marks_map)
self.closeness_kwargs = closeness_kwargs or dict()
def get_marks(self, test_id, args_kwargs): def get_marks(self, test_id, args_kwargs):
return [ return [
test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs) test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs)
] ]
def get_closeness_kwargs(self, test_id, *, dtype, device):
if isinstance(device, torch.device):
device = device.type
return self.closeness_kwargs.get((test_id, dtype, device), dict())
...@@ -61,11 +61,10 @@ class KernelInfo(InfoBase): ...@@ -61,11 +61,10 @@ class KernelInfo(InfoBase):
self.reference_inputs_fn = reference_inputs_fn self.reference_inputs_fn = reference_inputs_fn
DEFAULT_IMAGE_CLOSENESS_KWARGS = dict( DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS = {
atol=1e-5, (("TestKernels", "test_against_reference"), torch.float32, "cpu"): dict(atol=1e-5, rtol=0, agg_method="mean"),
rtol=0, (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): dict(atol=1e-5, rtol=0, agg_method="mean"),
agg_method="mean", }
)
def pil_reference_wrapper(pil_kernel): def pil_reference_wrapper(pil_kernel):
...@@ -176,7 +175,7 @@ KERNEL_INFOS.extend( ...@@ -176,7 +175,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_horizontal_flip_image_tensor, sample_inputs_fn=sample_inputs_horizontal_flip_image_tensor,
reference_fn=pil_reference_wrapper(F.horizontal_flip_image_pil), reference_fn=pil_reference_wrapper(F.horizontal_flip_image_pil),
reference_inputs_fn=reference_inputs_horizontal_flip_image_tensor, reference_inputs_fn=reference_inputs_horizontal_flip_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
), ),
KernelInfo( KernelInfo(
F.horizontal_flip_bounding_box, F.horizontal_flip_bounding_box,
...@@ -320,7 +319,7 @@ KERNEL_INFOS.extend( ...@@ -320,7 +319,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_resize_image_tensor, sample_inputs_fn=sample_inputs_resize_image_tensor,
reference_fn=reference_resize_image_tensor, reference_fn=reference_resize_image_tensor,
reference_inputs_fn=reference_inputs_resize_image_tensor, reference_inputs_fn=reference_inputs_resize_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
test_marks=[ test_marks=[
xfail_jit_python_scalar_arg("size"), xfail_jit_python_scalar_arg("size"),
], ],
...@@ -339,7 +338,7 @@ KERNEL_INFOS.extend( ...@@ -339,7 +338,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_resize_mask, sample_inputs_fn=sample_inputs_resize_mask,
reference_fn=reference_resize_mask, reference_fn=reference_resize_mask,
reference_inputs_fn=reference_inputs_resize_mask, reference_inputs_fn=reference_inputs_resize_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
test_marks=[ test_marks=[
xfail_jit_python_scalar_arg("size"), xfail_jit_python_scalar_arg("size"),
], ],
...@@ -556,7 +555,7 @@ KERNEL_INFOS.extend( ...@@ -556,7 +555,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_affine_image_tensor, sample_inputs_fn=sample_inputs_affine_image_tensor,
reference_fn=pil_reference_wrapper(F.affine_image_pil), reference_fn=pil_reference_wrapper(F.affine_image_pil),
reference_inputs_fn=reference_inputs_affine_image_tensor, reference_inputs_fn=reference_inputs_affine_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
test_marks=[ test_marks=[
xfail_jit_python_scalar_arg("shear"), xfail_jit_python_scalar_arg("shear"),
xfail_jit_tuple_instead_of_list("fill"), xfail_jit_tuple_instead_of_list("fill"),
...@@ -569,7 +568,9 @@ KERNEL_INFOS.extend( ...@@ -569,7 +568,9 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_affine_bounding_box, sample_inputs_fn=sample_inputs_affine_bounding_box,
reference_fn=reference_affine_bounding_box, reference_fn=reference_affine_bounding_box,
reference_inputs_fn=reference_inputs_affine_bounding_box, reference_inputs_fn=reference_inputs_affine_bounding_box,
closeness_kwargs=dict(atol=1, rtol=0), closeness_kwargs={
(("TestKernels", "test_against_reference"), torch.int64, "cpu"): dict(atol=1, rtol=0),
},
test_marks=[ test_marks=[
xfail_jit_python_scalar_arg("shear"), xfail_jit_python_scalar_arg("shear"),
], ],
...@@ -579,7 +580,7 @@ KERNEL_INFOS.extend( ...@@ -579,7 +580,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_affine_mask, sample_inputs_fn=sample_inputs_affine_mask,
reference_fn=reference_affine_mask, reference_fn=reference_affine_mask,
reference_inputs_fn=reference_inputs_resize_mask, reference_inputs_fn=reference_inputs_resize_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
test_marks=[ test_marks=[
xfail_jit_python_scalar_arg("shear"), xfail_jit_python_scalar_arg("shear"),
], ],
...@@ -668,7 +669,7 @@ KERNEL_INFOS.extend( ...@@ -668,7 +669,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_convert_color_space_image_tensor, sample_inputs_fn=sample_inputs_convert_color_space_image_tensor,
reference_fn=reference_convert_color_space_image_tensor, reference_fn=reference_convert_color_space_image_tensor,
reference_inputs_fn=reference_inputs_convert_color_space_image_tensor, reference_inputs_fn=reference_inputs_convert_color_space_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
), ),
KernelInfo( KernelInfo(
F.convert_color_space_video, F.convert_color_space_video,
...@@ -729,7 +730,7 @@ KERNEL_INFOS.extend( ...@@ -729,7 +730,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_vertical_flip_image_tensor, sample_inputs_fn=sample_inputs_vertical_flip_image_tensor,
reference_fn=pil_reference_wrapper(F.vertical_flip_image_pil), reference_fn=pil_reference_wrapper(F.vertical_flip_image_pil),
reference_inputs_fn=reference_inputs_vertical_flip_image_tensor, reference_inputs_fn=reference_inputs_vertical_flip_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
), ),
KernelInfo( KernelInfo(
F.vertical_flip_bounding_box, F.vertical_flip_bounding_box,
...@@ -820,7 +821,7 @@ KERNEL_INFOS.extend( ...@@ -820,7 +821,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_rotate_image_tensor, sample_inputs_fn=sample_inputs_rotate_image_tensor,
reference_fn=pil_reference_wrapper(F.rotate_image_pil), reference_fn=pil_reference_wrapper(F.rotate_image_pil),
reference_inputs_fn=reference_inputs_rotate_image_tensor, reference_inputs_fn=reference_inputs_rotate_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
test_marks=[ test_marks=[
xfail_jit_tuple_instead_of_list("fill"), xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok # TODO: check if this is a regression since it seems that should be supported if `int` is ok
...@@ -836,7 +837,7 @@ KERNEL_INFOS.extend( ...@@ -836,7 +837,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_rotate_mask, sample_inputs_fn=sample_inputs_rotate_mask,
reference_fn=reference_rotate_mask, reference_fn=reference_rotate_mask,
reference_inputs_fn=reference_inputs_rotate_mask, reference_inputs_fn=reference_inputs_rotate_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
), ),
KernelInfo( KernelInfo(
F.rotate_video, F.rotate_video,
...@@ -918,7 +919,7 @@ KERNEL_INFOS.extend( ...@@ -918,7 +919,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_crop_image_tensor, sample_inputs_fn=sample_inputs_crop_image_tensor,
reference_fn=pil_reference_wrapper(F.crop_image_pil), reference_fn=pil_reference_wrapper(F.crop_image_pil),
reference_inputs_fn=reference_inputs_crop_image_tensor, reference_inputs_fn=reference_inputs_crop_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
), ),
KernelInfo( KernelInfo(
F.crop_bounding_box, F.crop_bounding_box,
...@@ -931,7 +932,7 @@ KERNEL_INFOS.extend( ...@@ -931,7 +932,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_crop_mask, sample_inputs_fn=sample_inputs_crop_mask,
reference_fn=pil_reference_wrapper(F.crop_image_pil), reference_fn=pil_reference_wrapper(F.crop_image_pil),
reference_inputs_fn=reference_inputs_crop_mask, reference_inputs_fn=reference_inputs_crop_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
), ),
KernelInfo( KernelInfo(
F.crop_video, F.crop_video,
...@@ -1010,7 +1011,7 @@ KERNEL_INFOS.extend( ...@@ -1010,7 +1011,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_resized_crop_image_tensor, sample_inputs_fn=sample_inputs_resized_crop_image_tensor,
reference_fn=reference_resized_crop_image_tensor, reference_fn=reference_resized_crop_image_tensor,
reference_inputs_fn=reference_inputs_resized_crop_image_tensor, reference_inputs_fn=reference_inputs_resized_crop_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
), ),
KernelInfo( KernelInfo(
F.resized_crop_bounding_box, F.resized_crop_bounding_box,
...@@ -1021,7 +1022,7 @@ KERNEL_INFOS.extend( ...@@ -1021,7 +1022,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_resized_crop_mask, sample_inputs_fn=sample_inputs_resized_crop_mask,
reference_fn=pil_reference_wrapper(F.resized_crop_image_pil), reference_fn=pil_reference_wrapper(F.resized_crop_image_pil),
reference_inputs_fn=reference_inputs_resized_crop_mask, reference_inputs_fn=reference_inputs_resized_crop_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
), ),
KernelInfo( KernelInfo(
F.resized_crop_video, F.resized_crop_video,
...@@ -1144,7 +1145,7 @@ KERNEL_INFOS.extend( ...@@ -1144,7 +1145,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_pad_image_tensor, sample_inputs_fn=sample_inputs_pad_image_tensor,
reference_fn=pil_reference_wrapper(F.pad_image_pil), reference_fn=pil_reference_wrapper(F.pad_image_pil),
reference_inputs_fn=reference_inputs_pad_image_tensor, reference_inputs_fn=reference_inputs_pad_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
test_marks=[ test_marks=[
xfail_jit_tuple_instead_of_list("padding"), xfail_jit_tuple_instead_of_list("padding"),
xfail_jit_tuple_instead_of_list("fill"), xfail_jit_tuple_instead_of_list("fill"),
...@@ -1166,7 +1167,7 @@ KERNEL_INFOS.extend( ...@@ -1166,7 +1167,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_pad_mask, sample_inputs_fn=sample_inputs_pad_mask,
reference_fn=pil_reference_wrapper(F.pad_image_pil), reference_fn=pil_reference_wrapper(F.pad_image_pil),
reference_inputs_fn=reference_inputs_pad_mask, reference_inputs_fn=reference_inputs_pad_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
), ),
KernelInfo( KernelInfo(
F.pad_video, F.pad_video,
...@@ -1225,7 +1226,7 @@ KERNEL_INFOS.extend( ...@@ -1225,7 +1226,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_perspective_image_tensor, sample_inputs_fn=sample_inputs_perspective_image_tensor,
reference_fn=pil_reference_wrapper(F.perspective_image_pil), reference_fn=pil_reference_wrapper(F.perspective_image_pil),
reference_inputs_fn=reference_inputs_perspective_image_tensor, reference_inputs_fn=reference_inputs_perspective_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
), ),
KernelInfo( KernelInfo(
F.perspective_bounding_box, F.perspective_bounding_box,
...@@ -1236,7 +1237,7 @@ KERNEL_INFOS.extend( ...@@ -1236,7 +1237,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_perspective_mask, sample_inputs_fn=sample_inputs_perspective_mask,
reference_fn=pil_reference_wrapper(F.perspective_image_pil), reference_fn=pil_reference_wrapper(F.perspective_image_pil),
reference_inputs_fn=reference_inputs_perspective_mask, reference_inputs_fn=reference_inputs_perspective_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
), ),
KernelInfo( KernelInfo(
F.perspective_video, F.perspective_video,
...@@ -1306,7 +1307,7 @@ KERNEL_INFOS.extend( ...@@ -1306,7 +1307,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_elastic_image_tensor, sample_inputs_fn=sample_inputs_elastic_image_tensor,
reference_fn=pil_reference_wrapper(F.elastic_image_pil), reference_fn=pil_reference_wrapper(F.elastic_image_pil),
reference_inputs_fn=reference_inputs_elastic_image_tensor, reference_inputs_fn=reference_inputs_elastic_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
), ),
KernelInfo( KernelInfo(
F.elastic_bounding_box, F.elastic_bounding_box,
...@@ -1317,7 +1318,7 @@ KERNEL_INFOS.extend( ...@@ -1317,7 +1318,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_elastic_mask, sample_inputs_fn=sample_inputs_elastic_mask,
reference_fn=pil_reference_wrapper(F.elastic_image_pil), reference_fn=pil_reference_wrapper(F.elastic_image_pil),
reference_inputs_fn=reference_inputs_elastic_mask, reference_inputs_fn=reference_inputs_elastic_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
), ),
KernelInfo( KernelInfo(
F.elastic_video, F.elastic_video,
...@@ -1387,7 +1388,7 @@ KERNEL_INFOS.extend( ...@@ -1387,7 +1388,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_center_crop_image_tensor, sample_inputs_fn=sample_inputs_center_crop_image_tensor,
reference_fn=pil_reference_wrapper(F.center_crop_image_pil), reference_fn=pil_reference_wrapper(F.center_crop_image_pil),
reference_inputs_fn=reference_inputs_center_crop_image_tensor, reference_inputs_fn=reference_inputs_center_crop_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
test_marks=[ test_marks=[
xfail_jit_python_scalar_arg("output_size"), xfail_jit_python_scalar_arg("output_size"),
], ],
...@@ -1404,7 +1405,7 @@ KERNEL_INFOS.extend( ...@@ -1404,7 +1405,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_center_crop_mask, sample_inputs_fn=sample_inputs_center_crop_mask,
reference_fn=pil_reference_wrapper(F.center_crop_image_pil), reference_fn=pil_reference_wrapper(F.center_crop_image_pil),
reference_inputs_fn=reference_inputs_center_crop_mask, reference_inputs_fn=reference_inputs_center_crop_mask,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
test_marks=[ test_marks=[
xfail_jit_python_scalar_arg("output_size"), xfail_jit_python_scalar_arg("output_size"),
], ],
...@@ -1441,7 +1442,7 @@ KERNEL_INFOS.extend( ...@@ -1441,7 +1442,7 @@ KERNEL_INFOS.extend(
KernelInfo( KernelInfo(
F.gaussian_blur_image_tensor, F.gaussian_blur_image_tensor,
sample_inputs_fn=sample_inputs_gaussian_blur_image_tensor, sample_inputs_fn=sample_inputs_gaussian_blur_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
test_marks=[ test_marks=[
xfail_jit_python_scalar_arg("kernel_size"), xfail_jit_python_scalar_arg("kernel_size"),
xfail_jit_python_scalar_arg("sigma"), xfail_jit_python_scalar_arg("sigma"),
...@@ -1529,7 +1530,7 @@ KERNEL_INFOS.extend( ...@@ -1529,7 +1530,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_equalize_image_tensor, sample_inputs_fn=sample_inputs_equalize_image_tensor,
reference_fn=pil_reference_wrapper(F.equalize_image_pil), reference_fn=pil_reference_wrapper(F.equalize_image_pil),
reference_inputs_fn=reference_inputs_equalize_image_tensor, reference_inputs_fn=reference_inputs_equalize_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
), ),
KernelInfo( KernelInfo(
F.equalize_video, F.equalize_video,
...@@ -1566,7 +1567,7 @@ KERNEL_INFOS.extend( ...@@ -1566,7 +1567,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_invert_image_tensor, sample_inputs_fn=sample_inputs_invert_image_tensor,
reference_fn=pil_reference_wrapper(F.invert_image_pil), reference_fn=pil_reference_wrapper(F.invert_image_pil),
reference_inputs_fn=reference_inputs_invert_image_tensor, reference_inputs_fn=reference_inputs_invert_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
), ),
KernelInfo( KernelInfo(
F.invert_video, F.invert_video,
...@@ -1607,7 +1608,7 @@ KERNEL_INFOS.extend( ...@@ -1607,7 +1608,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_posterize_image_tensor, sample_inputs_fn=sample_inputs_posterize_image_tensor,
reference_fn=pil_reference_wrapper(F.posterize_image_pil), reference_fn=pil_reference_wrapper(F.posterize_image_pil),
reference_inputs_fn=reference_inputs_posterize_image_tensor, reference_inputs_fn=reference_inputs_posterize_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
), ),
KernelInfo( KernelInfo(
F.posterize_video, F.posterize_video,
...@@ -1651,7 +1652,7 @@ KERNEL_INFOS.extend( ...@@ -1651,7 +1652,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_solarize_image_tensor, sample_inputs_fn=sample_inputs_solarize_image_tensor,
reference_fn=pil_reference_wrapper(F.solarize_image_pil), reference_fn=pil_reference_wrapper(F.solarize_image_pil),
reference_inputs_fn=reference_inputs_solarize_image_tensor, reference_inputs_fn=reference_inputs_solarize_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
), ),
KernelInfo( KernelInfo(
F.solarize_video, F.solarize_video,
...@@ -1688,7 +1689,7 @@ KERNEL_INFOS.extend( ...@@ -1688,7 +1689,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_autocontrast_image_tensor, sample_inputs_fn=sample_inputs_autocontrast_image_tensor,
reference_fn=pil_reference_wrapper(F.autocontrast_image_pil), reference_fn=pil_reference_wrapper(F.autocontrast_image_pil),
reference_inputs_fn=reference_inputs_autocontrast_image_tensor, reference_inputs_fn=reference_inputs_autocontrast_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
), ),
KernelInfo( KernelInfo(
F.autocontrast_video, F.autocontrast_video,
...@@ -1729,7 +1730,7 @@ KERNEL_INFOS.extend( ...@@ -1729,7 +1730,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_adjust_sharpness_image_tensor, sample_inputs_fn=sample_inputs_adjust_sharpness_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_sharpness_image_pil), reference_fn=pil_reference_wrapper(F.adjust_sharpness_image_pil),
reference_inputs_fn=reference_inputs_adjust_sharpness_image_tensor, reference_inputs_fn=reference_inputs_adjust_sharpness_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
), ),
KernelInfo( KernelInfo(
F.adjust_sharpness_video, F.adjust_sharpness_video,
...@@ -1800,7 +1801,7 @@ KERNEL_INFOS.extend( ...@@ -1800,7 +1801,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_adjust_brightness_image_tensor, sample_inputs_fn=sample_inputs_adjust_brightness_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_brightness_image_pil), reference_fn=pil_reference_wrapper(F.adjust_brightness_image_pil),
reference_inputs_fn=reference_inputs_adjust_brightness_image_tensor, reference_inputs_fn=reference_inputs_adjust_brightness_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
), ),
KernelInfo( KernelInfo(
F.adjust_brightness_video, F.adjust_brightness_video,
...@@ -1841,7 +1842,7 @@ KERNEL_INFOS.extend( ...@@ -1841,7 +1842,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_adjust_contrast_image_tensor, sample_inputs_fn=sample_inputs_adjust_contrast_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_contrast_image_pil), reference_fn=pil_reference_wrapper(F.adjust_contrast_image_pil),
reference_inputs_fn=reference_inputs_adjust_contrast_image_tensor, reference_inputs_fn=reference_inputs_adjust_contrast_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
), ),
KernelInfo( KernelInfo(
F.adjust_contrast_video, F.adjust_contrast_video,
...@@ -1886,7 +1887,7 @@ KERNEL_INFOS.extend( ...@@ -1886,7 +1887,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_adjust_gamma_image_tensor, sample_inputs_fn=sample_inputs_adjust_gamma_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_gamma_image_pil), reference_fn=pil_reference_wrapper(F.adjust_gamma_image_pil),
reference_inputs_fn=reference_inputs_adjust_gamma_image_tensor, reference_inputs_fn=reference_inputs_adjust_gamma_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
), ),
KernelInfo( KernelInfo(
F.adjust_gamma_video, F.adjust_gamma_video,
...@@ -1927,7 +1928,7 @@ KERNEL_INFOS.extend( ...@@ -1927,7 +1928,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_adjust_hue_image_tensor, sample_inputs_fn=sample_inputs_adjust_hue_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_hue_image_pil), reference_fn=pil_reference_wrapper(F.adjust_hue_image_pil),
reference_inputs_fn=reference_inputs_adjust_hue_image_tensor, reference_inputs_fn=reference_inputs_adjust_hue_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
), ),
KernelInfo( KernelInfo(
F.adjust_hue_video, F.adjust_hue_video,
...@@ -1967,7 +1968,7 @@ KERNEL_INFOS.extend( ...@@ -1967,7 +1968,7 @@ KERNEL_INFOS.extend(
sample_inputs_fn=sample_inputs_adjust_saturation_image_tensor, sample_inputs_fn=sample_inputs_adjust_saturation_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_saturation_image_pil), reference_fn=pil_reference_wrapper(F.adjust_saturation_image_pil),
reference_inputs_fn=reference_inputs_adjust_saturation_image_tensor, reference_inputs_fn=reference_inputs_adjust_saturation_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
), ),
KernelInfo( KernelInfo(
F.adjust_saturation_video, F.adjust_saturation_video,
...@@ -2061,7 +2062,7 @@ KERNEL_INFOS.extend( ...@@ -2061,7 +2062,7 @@ KERNEL_INFOS.extend(
reference_fn=pil_reference_wrapper(F.five_crop_image_pil), reference_fn=pil_reference_wrapper(F.five_crop_image_pil),
reference_inputs_fn=reference_inputs_five_crop_image_tensor, reference_inputs_fn=reference_inputs_five_crop_image_tensor,
test_marks=_common_five_ten_crop_marks, test_marks=_common_five_ten_crop_marks,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
), ),
KernelInfo( KernelInfo(
F.five_crop_video, F.five_crop_video,
...@@ -2074,7 +2075,7 @@ KERNEL_INFOS.extend( ...@@ -2074,7 +2075,7 @@ KERNEL_INFOS.extend(
reference_fn=pil_reference_wrapper(F.ten_crop_image_pil), reference_fn=pil_reference_wrapper(F.ten_crop_image_pil),
reference_inputs_fn=reference_inputs_ten_crop_image_tensor, reference_inputs_fn=reference_inputs_ten_crop_image_tensor,
test_marks=_common_five_ten_crop_marks, test_marks=_common_five_ten_crop_marks,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,
), ),
KernelInfo( KernelInfo(
F.ten_crop_video, F.ten_crop_video,
......
...@@ -93,6 +93,13 @@ def fix_rng_seed(): ...@@ -93,6 +93,13 @@ def fix_rng_seed():
yield yield
@pytest.fixture()
def test_id(request):
test_class_name = request.cls.__name__ if request.cls is not None else None
test_function_name = request.node.originalname
return test_class_name, test_function_name
class TestKernels: class TestKernels:
sample_inputs = make_info_args_kwargs_parametrization( sample_inputs = make_info_args_kwargs_parametrization(
KERNEL_INFOS, KERNEL_INFOS,
...@@ -107,16 +114,20 @@ class TestKernels: ...@@ -107,16 +114,20 @@ class TestKernels:
@ignore_jit_warning_no_profile @ignore_jit_warning_no_profile
@sample_inputs @sample_inputs
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
def test_scripted_vs_eager(self, info, args_kwargs, device): def test_scripted_vs_eager(self, test_id, info, args_kwargs, device):
kernel_eager = info.kernel kernel_eager = info.kernel
kernel_scripted = script(kernel_eager) kernel_scripted = script(kernel_eager)
args, kwargs = args_kwargs.load(device) (input, *other_args), kwargs = args_kwargs.load(device)
actual = kernel_scripted(*args, **kwargs) actual = kernel_scripted(input, *other_args, **kwargs)
expected = kernel_eager(*args, **kwargs) expected = kernel_eager(input, *other_args, **kwargs)
assert_close(actual, expected, **info.closeness_kwargs) assert_close(
actual,
expected,
**info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
)
def _unbatch(self, batch, *, data_dims): def _unbatch(self, batch, *, data_dims):
if isinstance(batch, torch.Tensor): if isinstance(batch, torch.Tensor):
...@@ -137,7 +148,7 @@ class TestKernels: ...@@ -137,7 +148,7 @@ class TestKernels:
@sample_inputs @sample_inputs
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
def test_batched_vs_single(self, info, args_kwargs, device): def test_batched_vs_single(self, test_id, info, args_kwargs, device):
(batched_input, *other_args), kwargs = args_kwargs.load(device) (batched_input, *other_args), kwargs = args_kwargs.load(device)
feature_type = features.Image if features.is_simple_tensor(batched_input) else type(batched_input) feature_type = features.Image if features.is_simple_tensor(batched_input) else type(batched_input)
...@@ -168,7 +179,11 @@ class TestKernels: ...@@ -168,7 +179,11 @@ class TestKernels:
single_inputs = self._unbatch(batched_input, data_dims=data_dims) single_inputs = self._unbatch(batched_input, data_dims=data_dims)
expected = tree_map(lambda single_input: info.kernel(single_input, *other_args, **kwargs), single_inputs) expected = tree_map(lambda single_input: info.kernel(single_input, *other_args, **kwargs), single_inputs)
assert_close(actual, expected, **info.closeness_kwargs) assert_close(
actual,
expected,
**info.get_closeness_kwargs(test_id, dtype=batched_input.dtype, device=batched_input.device),
)
@sample_inputs @sample_inputs
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
...@@ -185,14 +200,19 @@ class TestKernels: ...@@ -185,14 +200,19 @@ class TestKernels:
@sample_inputs @sample_inputs
@needs_cuda @needs_cuda
def test_cuda_vs_cpu(self, info, args_kwargs): def test_cuda_vs_cpu(self, test_id, info, args_kwargs):
(input_cpu, *other_args), kwargs = args_kwargs.load("cpu") (input_cpu, *other_args), kwargs = args_kwargs.load("cpu")
input_cuda = input_cpu.to("cuda") input_cuda = input_cpu.to("cuda")
output_cpu = info.kernel(input_cpu, *other_args, **kwargs) output_cpu = info.kernel(input_cpu, *other_args, **kwargs)
output_cuda = info.kernel(input_cuda, *other_args, **kwargs) output_cuda = info.kernel(input_cuda, *other_args, **kwargs)
assert_close(output_cuda, output_cpu, check_device=False, **info.closeness_kwargs) assert_close(
output_cuda,
output_cpu,
check_device=False,
**info.get_closeness_kwargs(test_id, dtype=input_cuda.dtype, device=input_cuda.device),
)
@sample_inputs @sample_inputs
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
...@@ -208,13 +228,18 @@ class TestKernels: ...@@ -208,13 +228,18 @@ class TestKernels:
assert output.device == input.device assert output.device == input.device
@reference_inputs @reference_inputs
def test_against_reference(self, info, args_kwargs): def test_against_reference(self, test_id, info, args_kwargs):
args, kwargs = args_kwargs.load("cpu") (input, *other_args), kwargs = args_kwargs.load("cpu")
actual = info.kernel(*args, **kwargs) actual = info.kernel(input, *other_args, **kwargs)
expected = info.reference_fn(*args, **kwargs) expected = info.reference_fn(input, *other_args, **kwargs)
assert_close(actual, expected, check_dtype=False, **info.closeness_kwargs) assert_close(
actual,
expected,
check_dtype=False,
**info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
)
@pytest.fixture @pytest.fixture
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment