"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "8f4594074be2b2f8c868f311f24a030806a3deb5"
Unverified commit 791bc844, authored by Philip Meier, committed by GitHub

retain input dtype in all prototype transforms kernels (#6597)

* fix convert_format_bounding_box, affine_bounding_box, and resize_bounding_box

* fix perspective and elastic

* fix tests
parent 35b0b9ee
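
The root cause here is PyTorch's type promotion: multiplying an integer bounding box tensor by float ratios or affine parameters silently promotes the result to float, which is why several kernels below gain a trailing `.to(bounding_box.dtype)` cast. A minimal sketch of the behavior (plain PyTorch, not code from this diff):

```python
import torch

# An integer XYXY box and float resize ratios, mirroring resize_bounding_box.
bbox = torch.tensor([[10, 20, 30, 40]], dtype=torch.int64)
ratios = torch.tensor((0.5, 2.0))

scaled = bbox.view(-1, 2, 2).mul(ratios)
print(scaled.dtype)  # torch.float32 -- type promotion dropped the integer dtype

# The pattern this commit applies: cast back to the input dtype.
retained = scaled.to(bbox.dtype).view(bbox.shape)
print(retained.dtype)  # torch.int64
```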
@@ -303,8 +303,7 @@ def reference_affine_bounding_box(bounding_box, *, format, image_size, angle, tr
                 np.max(transformed_points[:, 0]),
                 np.max(transformed_points[:, 1]),
             ],
-            # FIXME: re-add this as soon as the kernel is fixed to also retain the dtype
-            # dtype=bbox.dtype,
+            dtype=bbox.dtype,
         )
     return F.convert_format_bounding_box(
         out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
@@ -369,6 +368,7 @@ KERNEL_INFOS.extend(
         sample_inputs_fn=sample_inputs_affine_bounding_box,
         reference_fn=reference_affine_bounding_box,
         reference_inputs_fn=reference_inputs_affine_bounding_box,
+        closeness_kwargs=dict(atol=1, rtol=0),
     ),
     KernelInfo(
         F.affine_mask,
...
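
The new `closeness_kwargs=dict(atol=1, rtol=0)` loosens the affine reference comparison because casting float coordinates back to an integer dtype truncates, so the kernel and the reference can legitimately land one pixel apart. A small illustration of why an absolute tolerance of one is the natural choice here:

```python
import torch

# Truncation toward zero means two near-identical float coordinates can end
# up one apart after the cast back to an integer input dtype.
a = torch.tensor([29.9999]).to(torch.int64)  # tensor([29])
b = torch.tensor([30.0001]).to(torch.int64)  # tensor([30])

# Fails with the default (exact) integer tolerances, passes with the new ones.
torch.testing.assert_close(a, b, atol=1, rtol=0)
```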
@@ -480,21 +480,6 @@ def test_eager_vs_scripted(functional_info, sample_input):
             functional_info,
             sample_input,
             id=f"{functional_info.name}-{idx}",
-            marks=[
-                *(
-                    [pytest.mark.xfail(strict=False)]
-                    if functional_info.name
-                    in {
-                        "rotate_bounding_box",
-                        "crop_bounding_box",
-                        "resized_crop_bounding_box",
-                        "perspective_bounding_box",
-                        "elastic_bounding_box",
-                        "center_crop_bounding_box",
-                    }
-                    else []
-                )
-            ],
         )
         for functional_info in FUNCTIONAL_INFOS
         for idx, sample_input in enumerate(functional_info.sample_inputs())
@@ -646,7 +631,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
         out_bbox,
         format=features.BoundingBoxFormat.XYXY,
         image_size=image_size,
-        dtype=torch.float32,
+        dtype=bbox.dtype,
         device=bbox.device,
     )
     return convert_format_bounding_box(
@@ -683,7 +668,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
         expected_bboxes = torch.stack(expected_bboxes)
     else:
         expected_bboxes = expected_bboxes[0]
-    torch.testing.assert_close(output_bboxes, expected_bboxes)
+    torch.testing.assert_close(output_bboxes, expected_bboxes, atol=1, rtol=0)


 @pytest.mark.parametrize("device", cpu_and_gpu())
@@ -1069,7 +1054,7 @@ def test_correctness_pad_bounding_box(device, padding):
         expected_bboxes = torch.stack(expected_bboxes)
     else:
         expected_bboxes = expected_bboxes[0]
-    torch.testing.assert_close(output_boxes, expected_bboxes)
+    torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)


 @pytest.mark.parametrize("device", cpu_and_gpu())
@@ -1188,7 +1173,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
         out_bbox,
         format=features.BoundingBoxFormat.XYXY,
         image_size=bbox.image_size,
-        dtype=torch.float32,
+        dtype=bbox.dtype,
         device=bbox.device,
     )
     return convert_format_bounding_box(
@@ -1222,7 +1207,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
         expected_bboxes = torch.stack(expected_bboxes)
     else:
         expected_bboxes = expected_bboxes[0]
-    torch.testing.assert_close(output_bboxes, expected_bboxes, rtol=1e-5, atol=1e-5)
+    torch.testing.assert_close(output_bboxes, expected_bboxes, rtol=0, atol=1)


 @pytest.mark.parametrize("device", cpu_and_gpu())
...
@@ -186,30 +186,7 @@ class TestCommon:
         assert_close(output_cuda, output_cpu, check_device=False, **info.closeness_kwargs)

-    @pytest.mark.parametrize(
-        ("info", "args_kwargs"),
-        [
-            pytest.param(
-                info,
-                args_kwargs,
-                id=f"{info.kernel_name}-",
-                marks=[
-                    *(
-                        [pytest.mark.xfail(strict=False)]
-                        if info.kernel_name
-                        in {
-                            "resize_bounding_box",
-                            "affine_bounding_box",
-                            "convert_format_bounding_box",
-                        }
-                        else []
-                    )
-                ],
-            )
-            for info in KERNEL_INFOS
-            for args_kwargs in info.sample_inputs_fn()
-        ],
-    )
+    @sample_inputs
     @pytest.mark.parametrize("device", cpu_and_gpu())
     def test_dtype_and_device_consistency(self, info, args_kwargs, device):
         (input, *other_args), kwargs = args_kwargs.load(device)
...
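
The removed parametrization, including its xfail marks that are no longer needed now that the kernels retain dtype, is folded into a shared `@sample_inputs` decorator whose definition is outside this diff. A hypothetical reconstruction of what such a decorator could look like; the name `sample_inputs` comes from the diff, the body is an assumption built from the removed code:

```python
import pytest

# Hypothetical sketch -- the real definition of `sample_inputs` is not shown
# in this diff. pytest.mark.parametrize returns a reusable decorator object,
# so it can be applied to several tests as `@sample_inputs`.
sample_inputs = pytest.mark.parametrize(
    ("info", "args_kwargs"),
    [
        pytest.param(info, args_kwargs, id=info.kernel_name)
        for info in KERNEL_INFOS
        for args_kwargs in info.sample_inputs_fn()
    ],
)
```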
@@ -153,7 +153,7 @@ def resize_bounding_box(
     old_height, old_width = image_size
     new_height, new_width = _compute_resized_output_size(image_size, size=size, max_size=max_size)
     ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device)
-    return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape)
+    return bounding_box.view(-1, 2, 2).mul(ratios).to(bounding_box.dtype).view(bounding_box.shape)


 def resize(
@@ -348,7 +348,7 @@ def _affine_bounding_box_xyxy(
         out_bboxes[:, 0::2] = out_bboxes[:, 0::2] - tr[:, 0]
         out_bboxes[:, 1::2] = out_bboxes[:, 1::2] - tr[:, 1]

-    return out_bboxes
+    return out_bboxes.to(bounding_box.dtype)


 def affine_bounding_box(
@@ -829,7 +829,7 @@ def perspective_bounding_box(
     transformed_points = transformed_points.view(-1, 4, 2)
     out_bbox_mins, _ = torch.min(transformed_points, dim=1)
     out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
-    out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)
+    out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)

     # out_bboxes should be of shape [N boxes, 4]
@@ -929,7 +929,7 @@ def elastic_bounding_box(
     transformed_points = transformed_points.view(-1, 4, 2)
     out_bbox_mins, _ = torch.min(transformed_points, dim=1)
     out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
-    out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)
+    out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)

     return convert_format_bounding_box(
         out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
...
@@ -67,7 +67,7 @@ def _cxcywh_to_xyxy(cxcywh: torch.Tensor) -> torch.Tensor:
     y1 = cy - 0.5 * h
     x2 = cx + 0.5 * w
     y2 = cy + 0.5 * h
-    return torch.stack((x1, y1, x2, y2), dim=-1)
+    return torch.stack((x1, y1, x2, y2), dim=-1).to(cxcywh.dtype)


 def _xyxy_to_cxcywh(xyxy: torch.Tensor) -> torch.Tensor:
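
For integer boxes, casting the CXCYWH conversion back to the input dtype is lossy when the center is fractional: true division promotes to float, and `.to()` truncates. A short demonstration of why integer round-trips through CXCYWH are only exact up to a pixel:

```python
import torch

xyxy = torch.tensor([10, 10, 15, 15], dtype=torch.int64)
x1, y1, x2, y2 = xyxy.unbind(-1)

cx = (x1 + x2) / 2        # true division promotes int64 -> float32
print(cx.dtype, cx)       # torch.float32 tensor(12.5000)
print(cx.to(xyxy.dtype))  # tensor(12) -- the half pixel is truncated away
```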
@@ -76,7 +76,7 @@ def _xyxy_to_cxcywh(xyxy: torch.Tensor) -> torch.Tensor:
     cy = (y1 + y2) / 2
     w = x2 - x1
     h = y2 - y1
-    return torch.stack((cx, cy, w, h), dim=-1)
+    return torch.stack((cx, cy, w, h), dim=-1).to(xyxy.dtype)


 def convert_format_bounding_box(
...
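
Taken together, the contract after this commit is: whatever dtype goes into a bounding box kernel comes back out. A quick end-to-end check; the prototype import path and keyword names are taken from this diff and may differ in later torchvision versions:

```python
import torch
from torchvision.prototype.transforms import functional as F

bbox = torch.tensor([[10, 20, 30, 40]], dtype=torch.int64)
# Keyword names follow the resize_bounding_box signature shown above.
out = F.resize_bounding_box(bbox, size=[256, 256], image_size=(128, 128))
assert out.dtype == bbox.dtype  # int64 in, int64 out
```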