Unverified Commit 55150bfb authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Use torch.testing.assert_close in test_ops.py (#3883)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent ea34cd1e
from common_utils import needs_cuda, cpu_only
from _assert_utils import assert_equal
import math
import unittest
import pytest
......@@ -78,7 +79,7 @@ class RoIOpTester(OpTester):
sampling_ratio=-1, device=device, dtype=self.dtype, **kwargs)
tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5
self.assertTrue(torch.allclose(gt_y.to(y.dtype), y, rtol=tol, atol=tol))
torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol)
def _test_backward(self, device, contiguous):
pool_size = 2
......@@ -363,7 +364,7 @@ class RoIAlignTester(RoIOpTester, unittest.TestCase):
abs_diff = torch.abs(qy[diff_idx].dequantize() - quantized_float_y[diff_idx].dequantize())
t_scale = torch.full_like(abs_diff, fill_value=scale)
self.assertTrue(torch.allclose(abs_diff, t_scale, atol=1e-5))
torch.testing.assert_close(abs_diff, t_scale, rtol=1e-5, atol=1e-5)
x = torch.randint(50, 100, size=(2, 3, 10, 10)).to(dtype)
qx = torch.quantize_per_tensor(x, scale=1, zero_point=0, dtype=torch.qint8)
......@@ -555,7 +556,7 @@ class TestNMS:
iou_thres = 0.2
keep32 = ops.nms(boxes, scores, iou_thres)
keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres)
assert torch.all(torch.eq(keep32, keep16))
assert_equal(keep32, keep16)
@cpu_only
def test_batched_nms_implementations(self):
......@@ -573,12 +574,13 @@ class TestNMS:
keep_vanilla = ops.boxes._batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
keep_trick = ops.boxes._batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
err_msg = "The vanilla and the trick implementation yield different nms outputs."
assert torch.allclose(keep_vanilla, keep_trick), err_msg
torch.testing.assert_close(
keep_vanilla, keep_trick, msg="The vanilla and the trick implementation yield different nms outputs."
)
# Also make sure an empty tensor is returned if boxes is empty
empty = torch.empty((0,), dtype=torch.int64)
assert torch.allclose(empty, ops.batched_nms(empty, None, None, None))
torch.testing.assert_close(empty, ops.batched_nms(empty, None, None, None))
class DeformConvTester(OpTester, unittest.TestCase):
......@@ -690,15 +692,17 @@ class DeformConvTester(OpTester, unittest.TestCase):
bias = layer.bias.data
expected = self.expected_fn(x, weight, offset, mask, bias, stride=stride, padding=padding, dilation=dilation)
self.assertTrue(torch.allclose(res.to(expected.dtype), expected, rtol=tol, atol=tol),
'\nres:\n{}\nexpected:\n{}'.format(res, expected))
torch.testing.assert_close(
res.to(expected), expected, rtol=tol, atol=tol, msg='\nres:\n{}\nexpected:\n{}'.format(res, expected)
)
# no modulation test
res = layer(x, offset)
expected = self.expected_fn(x, weight, offset, None, bias, stride=stride, padding=padding, dilation=dilation)
self.assertTrue(torch.allclose(res.to(expected.dtype), expected, rtol=tol, atol=tol),
'\nres:\n{}\nexpected:\n{}'.format(res, expected))
torch.testing.assert_close(
res.to(expected), expected, rtol=tol, atol=tol, msg='\nres:\n{}\nexpected:\n{}'.format(res, expected)
)
# test for wrong sizes
with self.assertRaises(RuntimeError):
......@@ -778,7 +782,7 @@ class DeformConvTester(OpTester, unittest.TestCase):
else:
self.assertTrue(init_weight.grad is not None)
res_grads = init_weight.grad.to("cpu")
self.assertTrue(true_cpu_grads.allclose(res_grads))
torch.testing.assert_close(true_cpu_grads, res_grads)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_autocast(self):
......@@ -812,14 +816,14 @@ class FrozenBNTester(unittest.TestCase):
bn = torch.nn.BatchNorm2d(sample_size[1]).eval()
bn.load_state_dict(state_dict)
# Difference is expected to fall in an acceptable range
self.assertTrue(torch.allclose(fbn(x), bn(x), atol=1e-6))
torch.testing.assert_close(fbn(x), bn(x), rtol=1e-5, atol=1e-6)
# Check computation for eps > 0
fbn = ops.misc.FrozenBatchNorm2d(sample_size[1], eps=1e-5)
fbn.load_state_dict(state_dict, strict=False)
bn = torch.nn.BatchNorm2d(sample_size[1], eps=1e-5).eval()
bn.load_state_dict(state_dict)
self.assertTrue(torch.allclose(fbn(x), bn(x), atol=1e-6))
torch.testing.assert_close(fbn(x), bn(x), rtol=1e-5, atol=1e-6)
def test_frozenbatchnorm2d_n_arg(self):
"""Ensure a warning is thrown when passing `n` kwarg
......@@ -860,20 +864,10 @@ class BoxTester(unittest.TestCase):
exp_xyxy = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
box_same = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xyxy")
self.assertEqual(exp_xyxy.size(), torch.Size([4, 4]))
self.assertEqual(exp_xyxy.dtype, box_tensor.dtype)
assert torch.all(torch.eq(box_same, exp_xyxy)).item()
box_same = ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="xywh")
self.assertEqual(exp_xyxy.size(), torch.Size([4, 4]))
self.assertEqual(exp_xyxy.dtype, box_tensor.dtype)
assert torch.all(torch.eq(box_same, exp_xyxy)).item()
box_same = ops.box_convert(box_tensor, in_fmt="cxcywh", out_fmt="cxcywh")
self.assertEqual(exp_xyxy.size(), torch.Size([4, 4]))
self.assertEqual(exp_xyxy.dtype, box_tensor.dtype)
assert torch.all(torch.eq(box_same, exp_xyxy)).item()
assert exp_xyxy.size() == torch.Size([4, 4])
assert_equal(ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xyxy"), exp_xyxy)
assert_equal(ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="xywh"), exp_xyxy)
assert_equal(ops.box_convert(box_tensor, in_fmt="cxcywh", out_fmt="cxcywh"), exp_xyxy)
def test_bbox_xyxy_xywh(self):
# Simple test convert boxes to xywh and back. Make sure they are same.
......@@ -883,16 +877,13 @@ class BoxTester(unittest.TestCase):
exp_xywh = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
[10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float)
assert exp_xywh.size() == torch.Size([4, 4])
box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh")
self.assertEqual(exp_xywh.size(), torch.Size([4, 4]))
self.assertEqual(exp_xywh.dtype, box_tensor.dtype)
assert torch.all(torch.eq(box_xywh, exp_xywh)).item()
assert_equal(box_xywh, exp_xywh)
# Reverse conversion
box_xyxy = ops.box_convert(box_xywh, in_fmt="xywh", out_fmt="xyxy")
self.assertEqual(box_xyxy.size(), torch.Size([4, 4]))
self.assertEqual(box_xyxy.dtype, box_tensor.dtype)
assert torch.all(torch.eq(box_xyxy, box_tensor)).item()
assert_equal(box_xyxy, box_tensor)
def test_bbox_xyxy_cxcywh(self):
# Simple test convert boxes to xywh and back. Make sure they are same.
......@@ -902,16 +893,13 @@ class BoxTester(unittest.TestCase):
exp_cxcywh = torch.tensor([[50, 50, 100, 100], [0, 0, 0, 0],
[20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float)
assert exp_cxcywh.size() == torch.Size([4, 4])
box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh")
self.assertEqual(exp_cxcywh.size(), torch.Size([4, 4]))
self.assertEqual(exp_cxcywh.dtype, box_tensor.dtype)
assert torch.all(torch.eq(box_cxcywh, exp_cxcywh)).item()
assert_equal(box_cxcywh, exp_cxcywh)
# Reverse conversion
box_xyxy = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xyxy")
self.assertEqual(box_xyxy.size(), torch.Size([4, 4]))
self.assertEqual(box_xyxy.dtype, box_tensor.dtype)
assert torch.all(torch.eq(box_xyxy, box_tensor)).item()
assert_equal(box_xyxy, box_tensor)
def test_bbox_xywh_cxcywh(self):
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
......@@ -921,16 +909,13 @@ class BoxTester(unittest.TestCase):
exp_cxcywh = torch.tensor([[50, 50, 100, 100], [0, 0, 0, 0],
[20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float)
assert exp_cxcywh.size() == torch.Size([4, 4])
box_cxcywh = ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="cxcywh")
self.assertEqual(exp_cxcywh.size(), torch.Size([4, 4]))
self.assertEqual(exp_cxcywh.dtype, box_tensor.dtype)
assert torch.all(torch.eq(box_cxcywh, exp_cxcywh)).item()
assert_equal(box_cxcywh, exp_cxcywh)
# Reverse conversion
box_xywh = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xywh")
self.assertEqual(box_xywh.size(), torch.Size([4, 4]))
self.assertEqual(box_xywh.dtype, box_tensor.dtype)
assert torch.all(torch.eq(box_xywh, box_tensor)).item()
assert_equal(box_xywh, box_tensor)
def test_bbox_invalid(self):
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
......@@ -951,19 +936,18 @@ class BoxTester(unittest.TestCase):
box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh")
scripted_xywh = scripted_fn(box_tensor, 'xyxy', 'xywh')
self.assertTrue((scripted_xywh - box_xywh).abs().max() < TOLERANCE)
torch.testing.assert_close(scripted_xywh, box_xywh, rtol=0.0, atol=TOLERANCE)
box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh")
scripted_cxcywh = scripted_fn(box_tensor, 'xyxy', 'cxcywh')
self.assertTrue((scripted_cxcywh - box_cxcywh).abs().max() < TOLERANCE)
torch.testing.assert_close(scripted_cxcywh, box_cxcywh, rtol=0.0, atol=TOLERANCE)
class BoxAreaTester(unittest.TestCase):
def test_box_area(self):
def area_check(box, expected, tolerance=1e-4):
out = ops.box_area(box)
assert out.size() == expected.size()
assert ((out - expected).abs().max() < tolerance).item()
torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)
# Check for int boxes
for dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
......@@ -991,8 +975,7 @@ class BoxIouTester(unittest.TestCase):
def test_iou(self):
def iou_check(box, expected, tolerance=1e-4):
out = ops.box_iou(box, box)
assert out.size() == expected.size()
assert ((out - expected).abs().max() < tolerance).item()
torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)
# Check for int boxes
for dtype in [torch.int16, torch.int32, torch.int64]:
......@@ -1013,8 +996,7 @@ class GenBoxIouTester(unittest.TestCase):
def test_gen_iou(self):
def gen_iou_check(box, expected, tolerance=1e-4):
out = ops.generalized_box_iou(box, box)
assert out.size() == expected.size()
assert ((out - expected).abs().max() < tolerance).item()
torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)
# Check for int boxes
for dtype in [torch.int16, torch.int32, torch.int64]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment