Unverified Commit a2151b96 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

replace assert torch.allclose with torch.testing.assert_allclose (#6895)

parent 79ca506c
......@@ -20,7 +20,7 @@ class MaxvitTester(unittest.TestCase):
x_hat = partition(x, partition_size)
x_hat = departition(x_hat, partition_size, n_partitions, n_partitions)
assert torch.allclose(x, x_hat)
torch.testing.assert_close(x, x_hat)
def test_maxvit_grid_partition(self):
input_shape = (1, 3, 224, 224)
......@@ -39,7 +39,7 @@ class MaxvitTester(unittest.TestCase):
x_hat = post_swap(x_hat)
x_hat = departition(x_hat, n_partitions, partition_size, partition_size)
assert torch.allclose(x, x_hat)
torch.testing.assert_close(x, x_hat)
if __name__ == "__main__":
......
......@@ -630,7 +630,7 @@ class TestNMS:
boxes, scores = self._create_tensors_with_iou(1000, iou)
keep_ref = self._reference_nms(boxes, scores, iou)
keep = ops.nms(boxes, scores, iou)
assert torch.allclose(keep, keep_ref), err_msg.format(iou)
torch.testing.assert_close(keep, keep_ref, msg=err_msg.format(iou))
def test_nms_input_errors(self):
with pytest.raises(RuntimeError):
......@@ -661,7 +661,7 @@ class TestNMS:
keep = ops.nms(boxes, scores, iou)
qkeep = ops.nms(qboxes, qscores, iou)
assert torch.allclose(qkeep, keep), err_msg.format(iou)
torch.testing.assert_close(qkeep, keep, msg=err_msg.format(iou))
@needs_cuda
@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
......@@ -1237,7 +1237,7 @@ class TestIouBase:
boxes2 = gen_box(7)
a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn)
b = target_fn(boxes1, boxes2)
assert torch.allclose(a, b)
torch.testing.assert_close(a, b)
class TestBoxIou(TestIouBase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment