Unverified Commit 924d373c authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

fix flaky test for rotate_bounding_box (#7362)


Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
parent feda8b7b
...@@ -351,7 +351,7 @@ assert_equal = functools.partial(assert_close, rtol=0, atol=0) ...@@ -351,7 +351,7 @@ assert_equal = functools.partial(assert_close, rtol=0, atol=0)
def parametrized_error_message(*args, **kwargs): def parametrized_error_message(*args, **kwargs):
def to_str(obj): def to_str(obj):
if isinstance(obj, torch.Tensor) and obj.numel() > 10: if isinstance(obj, torch.Tensor) and obj.numel() > 30:
return f"tensor(shape={list(obj.shape)}, dtype={obj.dtype}, device={obj.device})" return f"tensor(shape={list(obj.shape)}, dtype={obj.dtype}, device={obj.device})"
elif isinstance(obj, enum.Enum): elif isinstance(obj, enum.Enum):
return f"{type(obj).__name__}.{obj.name}" return f"{type(obj).__name__}.{obj.name}"
......
...@@ -146,7 +146,7 @@ class TestKernels: ...@@ -146,7 +146,7 @@ class TestKernels:
actual, actual,
expected, expected,
**info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device), **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
msg=parametrized_error_message(*([actual, expected] + other_args), **kwargs), msg=parametrized_error_message(input, other_args, **kwargs),
) )
def _unbatch(self, batch, *, data_dims): def _unbatch(self, batch, *, data_dims):
...@@ -204,7 +204,7 @@ class TestKernels: ...@@ -204,7 +204,7 @@ class TestKernels:
actual, actual,
expected, expected,
**info.get_closeness_kwargs(test_id, dtype=batched_input.dtype, device=batched_input.device), **info.get_closeness_kwargs(test_id, dtype=batched_input.dtype, device=batched_input.device),
msg=parametrized_error_message(*other_args, **kwargs), msg=parametrized_error_message(batched_input, *other_args, **kwargs),
) )
@sample_inputs @sample_inputs
...@@ -236,7 +236,7 @@ class TestKernels: ...@@ -236,7 +236,7 @@ class TestKernels:
output_cpu, output_cpu,
check_device=False, check_device=False,
**info.get_closeness_kwargs(test_id, dtype=input_cuda.dtype, device=input_cuda.device), **info.get_closeness_kwargs(test_id, dtype=input_cuda.dtype, device=input_cuda.device),
msg=parametrized_error_message(*other_args, **kwargs), msg=parametrized_error_message(input_cpu, *other_args, **kwargs),
) )
@sample_inputs @sample_inputs
...@@ -294,7 +294,7 @@ class TestKernels: ...@@ -294,7 +294,7 @@ class TestKernels:
actual, actual,
expected, expected,
**info.get_closeness_kwargs(test_id, dtype=torch.float32, device=input.device), **info.get_closeness_kwargs(test_id, dtype=torch.float32, device=input.device),
msg=parametrized_error_message(*other_args, **kwargs), msg=parametrized_error_message(input, *other_args, **kwargs),
) )
......
...@@ -860,8 +860,8 @@ KERNEL_INFOS.extend( ...@@ -860,8 +860,8 @@ KERNEL_INFOS.extend(
reference_fn=reference_rotate_bounding_box, reference_fn=reference_rotate_bounding_box,
reference_inputs_fn=reference_inputs_rotate_bounding_box, reference_inputs_fn=reference_inputs_rotate_bounding_box,
closeness_kwargs={ closeness_kwargs={
**scripted_vs_eager_float64_tolerances("cpu", atol=1e-6, rtol=1e-6), **scripted_vs_eager_float64_tolerances("cpu", atol=1e-4, rtol=1e-4),
**scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5), **scripted_vs_eager_float64_tolerances("cuda", atol=1e-4, rtol=1e-4),
}, },
), ),
KernelInfo( KernelInfo(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment