Unverified Commit 924d373c authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

fix flaky test for rotate_bounding_box (#7362)


Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
parent feda8b7b
......@@ -351,7 +351,7 @@ assert_equal = functools.partial(assert_close, rtol=0, atol=0)
def parametrized_error_message(*args, **kwargs):
def to_str(obj):
if isinstance(obj, torch.Tensor) and obj.numel() > 10:
if isinstance(obj, torch.Tensor) and obj.numel() > 30:
return f"tensor(shape={list(obj.shape)}, dtype={obj.dtype}, device={obj.device})"
elif isinstance(obj, enum.Enum):
return f"{type(obj).__name__}.{obj.name}"
......
......@@ -146,7 +146,7 @@ class TestKernels:
actual,
expected,
**info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
msg=parametrized_error_message(*([actual, expected] + other_args), **kwargs),
msg=parametrized_error_message(input, other_args, **kwargs),
)
def _unbatch(self, batch, *, data_dims):
......@@ -204,7 +204,7 @@ class TestKernels:
actual,
expected,
**info.get_closeness_kwargs(test_id, dtype=batched_input.dtype, device=batched_input.device),
msg=parametrized_error_message(*other_args, **kwargs),
msg=parametrized_error_message(batched_input, *other_args, **kwargs),
)
@sample_inputs
......@@ -236,7 +236,7 @@ class TestKernels:
output_cpu,
check_device=False,
**info.get_closeness_kwargs(test_id, dtype=input_cuda.dtype, device=input_cuda.device),
msg=parametrized_error_message(*other_args, **kwargs),
msg=parametrized_error_message(input_cpu, *other_args, **kwargs),
)
@sample_inputs
......@@ -294,7 +294,7 @@ class TestKernels:
actual,
expected,
**info.get_closeness_kwargs(test_id, dtype=torch.float32, device=input.device),
msg=parametrized_error_message(*other_args, **kwargs),
msg=parametrized_error_message(input, *other_args, **kwargs),
)
......
......@@ -860,8 +860,8 @@ KERNEL_INFOS.extend(
reference_fn=reference_rotate_bounding_box,
reference_inputs_fn=reference_inputs_rotate_bounding_box,
closeness_kwargs={
**scripted_vs_eager_float64_tolerances("cpu", atol=1e-6, rtol=1e-6),
**scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
**scripted_vs_eager_float64_tolerances("cpu", atol=1e-4, rtol=1e-4),
**scripted_vs_eager_float64_tolerances("cuda", atol=1e-4, rtol=1e-4),
},
),
KernelInfo(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment