Unverified Commit 55150bfb authored by Nicolas Hug, committed by GitHub

Use torch.testing.assert_close in test_ops.py (#3883)


Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent ea34cd1e
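This commit migrates the assertions in `test_ops.py` from boolean `torch.allclose` / `torch.eq` checks to `torch.testing.assert_close`, which raises an `AssertionError` carrying a mismatch report (number of mismatched elements, greatest absolute and relative difference) instead of failing on a bare boolean. A minimal sketch of the before/after pattern follows; the tensor values are made up, and the `assert_equal` definition shown is an assumption about the `_assert_utils` helper, which this diff imports but does not show:

```python
import functools

import torch

actual = torch.tensor([1.0, 2.0, 3.0])
expected = torch.tensor([1.0, 2.0, 3.0 + 1e-7])

# Before: allclose returns a bool, so a failure reports nothing about
# where or by how much the tensors diverge.
assert torch.allclose(actual, expected, rtol=1e-5, atol=1e-6)

# After: assert_close raises AssertionError with a diagnostic message and
# also validates shape, dtype, and device by default.
torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-6)

# Exact comparisons (the old `torch.all(torch.eq(a, b))` pattern) go through
# the suite's `assert_equal` helper; a plausible definition (assumption, not
# shown in this diff):
assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
assert_equal(torch.tensor([1, 2]), torch.tensor([1, 2]))
```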
test_ops.py:

 from common_utils import needs_cuda, cpu_only
+from _assert_utils import assert_equal
 import math
 import unittest
 import pytest
@@ -78,7 +79,7 @@ class RoIOpTester(OpTester):
                                  sampling_ratio=-1, device=device, dtype=self.dtype, **kwargs)
 
         tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5
-        self.assertTrue(torch.allclose(gt_y.to(y.dtype), y, rtol=tol, atol=tol))
+        torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol)
 
     def _test_backward(self, device, contiguous):
         pool_size = 2
@@ -363,7 +364,7 @@ class RoIAlignTester(RoIOpTester, unittest.TestCase):
             abs_diff = torch.abs(qy[diff_idx].dequantize() - quantized_float_y[diff_idx].dequantize())
             t_scale = torch.full_like(abs_diff, fill_value=scale)
-            self.assertTrue(torch.allclose(abs_diff, t_scale, atol=1e-5))
+            torch.testing.assert_close(abs_diff, t_scale, rtol=1e-5, atol=1e-5)
 
         x = torch.randint(50, 100, size=(2, 3, 10, 10)).to(dtype)
         qx = torch.quantize_per_tensor(x, scale=1, zero_point=0, dtype=torch.qint8)
@@ -555,7 +556,7 @@ class TestNMS:
         iou_thres = 0.2
         keep32 = ops.nms(boxes, scores, iou_thres)
         keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres)
-        assert torch.all(torch.eq(keep32, keep16))
+        assert_equal(keep32, keep16)
 
     @cpu_only
     def test_batched_nms_implementations(self):
@@ -573,12 +574,13 @@ class TestNMS:
         keep_vanilla = ops.boxes._batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
         keep_trick = ops.boxes._batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
 
-        err_msg = "The vanilla and the trick implementation yield different nms outputs."
-        assert torch.allclose(keep_vanilla, keep_trick), err_msg
+        torch.testing.assert_close(
+            keep_vanilla, keep_trick, msg="The vanilla and the trick implementation yield different nms outputs."
+        )
 
         # Also make sure an empty tensor is returned if boxes is empty
         empty = torch.empty((0,), dtype=torch.int64)
-        assert torch.allclose(empty, ops.batched_nms(empty, None, None, None))
+        torch.testing.assert_close(empty, ops.batched_nms(empty, None, None, None))
 
 
 class DeformConvTester(OpTester, unittest.TestCase):
@@ -690,15 +692,17 @@ class DeformConvTester(OpTester, unittest.TestCase):
         bias = layer.bias.data
         expected = self.expected_fn(x, weight, offset, mask, bias, stride=stride, padding=padding, dilation=dilation)
 
-        self.assertTrue(torch.allclose(res.to(expected.dtype), expected, rtol=tol, atol=tol),
-                        '\nres:\n{}\nexpected:\n{}'.format(res, expected))
+        torch.testing.assert_close(
+            res.to(expected), expected, rtol=tol, atol=tol, msg='\nres:\n{}\nexpected:\n{}'.format(res, expected)
+        )
 
         # no modulation test
         res = layer(x, offset)
         expected = self.expected_fn(x, weight, offset, None, bias, stride=stride, padding=padding, dilation=dilation)
 
-        self.assertTrue(torch.allclose(res.to(expected.dtype), expected, rtol=tol, atol=tol),
-                        '\nres:\n{}\nexpected:\n{}'.format(res, expected))
+        torch.testing.assert_close(
+            res.to(expected), expected, rtol=tol, atol=tol, msg='\nres:\n{}\nexpected:\n{}'.format(res, expected)
+        )
 
         # test for wrong sizes
         with self.assertRaises(RuntimeError):
@@ -778,7 +782,7 @@ class DeformConvTester(OpTester, unittest.TestCase):
         else:
             self.assertTrue(init_weight.grad is not None)
             res_grads = init_weight.grad.to("cpu")
-            self.assertTrue(true_cpu_grads.allclose(res_grads))
+            torch.testing.assert_close(true_cpu_grads, res_grads)
 
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
     def test_autocast(self):
@@ -812,14 +816,14 @@ class FrozenBNTester(unittest.TestCase):
         bn = torch.nn.BatchNorm2d(sample_size[1]).eval()
         bn.load_state_dict(state_dict)
         # Difference is expected to fall in an acceptable range
-        self.assertTrue(torch.allclose(fbn(x), bn(x), atol=1e-6))
+        torch.testing.assert_close(fbn(x), bn(x), rtol=1e-5, atol=1e-6)
 
         # Check computation for eps > 0
         fbn = ops.misc.FrozenBatchNorm2d(sample_size[1], eps=1e-5)
         fbn.load_state_dict(state_dict, strict=False)
         bn = torch.nn.BatchNorm2d(sample_size[1], eps=1e-5).eval()
         bn.load_state_dict(state_dict)
-        self.assertTrue(torch.allclose(fbn(x), bn(x), atol=1e-6))
+        torch.testing.assert_close(fbn(x), bn(x), rtol=1e-5, atol=1e-6)
 
     def test_frozenbatchnorm2d_n_arg(self):
         """Ensure a warning is thrown when passing `n` kwarg
@@ -860,20 +864,10 @@ class BoxTester(unittest.TestCase):
         exp_xyxy = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
                                  [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
 
-        box_same = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xyxy")
-        self.assertEqual(exp_xyxy.size(), torch.Size([4, 4]))
-        self.assertEqual(exp_xyxy.dtype, box_tensor.dtype)
-        assert torch.all(torch.eq(box_same, exp_xyxy)).item()
-
-        box_same = ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="xywh")
-        self.assertEqual(exp_xyxy.size(), torch.Size([4, 4]))
-        self.assertEqual(exp_xyxy.dtype, box_tensor.dtype)
-        assert torch.all(torch.eq(box_same, exp_xyxy)).item()
-
-        box_same = ops.box_convert(box_tensor, in_fmt="cxcywh", out_fmt="cxcywh")
-        self.assertEqual(exp_xyxy.size(), torch.Size([4, 4]))
-        self.assertEqual(exp_xyxy.dtype, box_tensor.dtype)
-        assert torch.all(torch.eq(box_same, exp_xyxy)).item()
+        assert exp_xyxy.size() == torch.Size([4, 4])
+        assert_equal(ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xyxy"), exp_xyxy)
+        assert_equal(ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="xywh"), exp_xyxy)
+        assert_equal(ops.box_convert(box_tensor, in_fmt="cxcywh", out_fmt="cxcywh"), exp_xyxy)
 
     def test_bbox_xyxy_xywh(self):
         # Simple test convert boxes to xywh and back. Make sure they are same.
@@ -883,16 +877,13 @@ class BoxTester(unittest.TestCase):
         exp_xywh = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
                                  [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float)
+        assert exp_xywh.size() == torch.Size([4, 4])
 
         box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh")
-        self.assertEqual(exp_xywh.size(), torch.Size([4, 4]))
-        self.assertEqual(exp_xywh.dtype, box_tensor.dtype)
-        assert torch.all(torch.eq(box_xywh, exp_xywh)).item()
+        assert_equal(box_xywh, exp_xywh)
 
         # Reverse conversion
         box_xyxy = ops.box_convert(box_xywh, in_fmt="xywh", out_fmt="xyxy")
-        self.assertEqual(box_xyxy.size(), torch.Size([4, 4]))
-        self.assertEqual(box_xyxy.dtype, box_tensor.dtype)
-        assert torch.all(torch.eq(box_xyxy, box_tensor)).item()
+        assert_equal(box_xyxy, box_tensor)
 
     def test_bbox_xyxy_cxcywh(self):
         # Simple test convert boxes to xywh and back. Make sure they are same.
@@ -902,16 +893,13 @@ class BoxTester(unittest.TestCase):
         exp_cxcywh = torch.tensor([[50, 50, 100, 100], [0, 0, 0, 0],
                                    [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float)
+        assert exp_cxcywh.size() == torch.Size([4, 4])
 
         box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh")
-        self.assertEqual(exp_cxcywh.size(), torch.Size([4, 4]))
-        self.assertEqual(exp_cxcywh.dtype, box_tensor.dtype)
-        assert torch.all(torch.eq(box_cxcywh, exp_cxcywh)).item()
+        assert_equal(box_cxcywh, exp_cxcywh)
 
         # Reverse conversion
         box_xyxy = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xyxy")
-        self.assertEqual(box_xyxy.size(), torch.Size([4, 4]))
-        self.assertEqual(box_xyxy.dtype, box_tensor.dtype)
-        assert torch.all(torch.eq(box_xyxy, box_tensor)).item()
+        assert_equal(box_xyxy, box_tensor)
 
     def test_bbox_xywh_cxcywh(self):
         box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
@@ -921,16 +909,13 @@ class BoxTester(unittest.TestCase):
         exp_cxcywh = torch.tensor([[50, 50, 100, 100], [0, 0, 0, 0],
                                    [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float)
+        assert exp_cxcywh.size() == torch.Size([4, 4])
 
         box_cxcywh = ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="cxcywh")
-        self.assertEqual(exp_cxcywh.size(), torch.Size([4, 4]))
-        self.assertEqual(exp_cxcywh.dtype, box_tensor.dtype)
-        assert torch.all(torch.eq(box_cxcywh, exp_cxcywh)).item()
+        assert_equal(box_cxcywh, exp_cxcywh)
 
         # Reverse conversion
         box_xywh = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xywh")
-        self.assertEqual(box_xywh.size(), torch.Size([4, 4]))
-        self.assertEqual(box_xywh.dtype, box_tensor.dtype)
-        assert torch.all(torch.eq(box_xywh, box_tensor)).item()
+        assert_equal(box_xywh, box_tensor)
 
     def test_bbox_invalid(self):
         box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
@@ -951,19 +936,18 @@ class BoxTester(unittest.TestCase):
         box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh")
         scripted_xywh = scripted_fn(box_tensor, 'xyxy', 'xywh')
-        self.assertTrue((scripted_xywh - box_xywh).abs().max() < TOLERANCE)
+        torch.testing.assert_close(scripted_xywh, box_xywh, rtol=0.0, atol=TOLERANCE)
 
         box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh")
         scripted_cxcywh = scripted_fn(box_tensor, 'xyxy', 'cxcywh')
-        self.assertTrue((scripted_cxcywh - box_cxcywh).abs().max() < TOLERANCE)
+        torch.testing.assert_close(scripted_cxcywh, box_cxcywh, rtol=0.0, atol=TOLERANCE)
 
 
 class BoxAreaTester(unittest.TestCase):
     def test_box_area(self):
         def area_check(box, expected, tolerance=1e-4):
             out = ops.box_area(box)
-            assert out.size() == expected.size()
-            assert ((out - expected).abs().max() < tolerance).item()
+            torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)
 
         # Check for int boxes
         for dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
@@ -991,8 +975,7 @@ class BoxIouTester(unittest.TestCase):
     def test_iou(self):
         def iou_check(box, expected, tolerance=1e-4):
             out = ops.box_iou(box, box)
-            assert out.size() == expected.size()
-            assert ((out - expected).abs().max() < tolerance).item()
+            torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)
 
         # Check for int boxes
         for dtype in [torch.int16, torch.int32, torch.int64]:
@@ -1013,8 +996,7 @@ class GenBoxIouTester(unittest.TestCase):
    def test_gen_iou(self):
        def gen_iou_check(box, expected, tolerance=1e-4):
            out = ops.generalized_box_iou(box, box)
-           assert out.size() == expected.size()
-           assert ((out - expected).abs().max() < tolerance).item()
+           torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)

        # Check for int boxes
        for dtype in [torch.int16, torch.int32, torch.int64]:
...
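Several of the rewritten checks previously bounded only the maximum absolute difference and compared integer outputs against float expectations. The new calls express that with `rtol=0.0` (so `atol` is the sole bound) and `check_dtype=False` (so the int-vs-float dtype difference is tolerated). A small sketch of the pattern, with illustrative values not taken from the test data:

```python
import torch

out = torch.tensor([10000, 0, 400, 4200], dtype=torch.int32)  # e.g. integer box areas
expected = torch.tensor([10000.0, 0.0, 400.0, 4200.0])        # float reference values

# rtol=0.0 makes atol the sole tolerance, mirroring the old
# `(out - expected).abs().max() < tolerance` check; check_dtype=False
# permits the int32-vs-float32 comparison (both are promoted to a
# common dtype before comparing). Shape is still checked, which
# subsumes the old explicit `out.size() == expected.size()` assert.
torch.testing.assert_close(out, expected, rtol=0.0, atol=1e-4, check_dtype=False)
```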