"docs/vscode:/vscode.git/clone" did not exist on "1ccbfbb663399d3aa363af513e5a8352f3afdb35"
Unverified Commit a2151b96 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

replace assert torch.allclose with torch.testing.assert_allclose (#6895)

parent 79ca506c
...@@ -20,7 +20,7 @@ class MaxvitTester(unittest.TestCase): ...@@ -20,7 +20,7 @@ class MaxvitTester(unittest.TestCase):
x_hat = partition(x, partition_size) x_hat = partition(x, partition_size)
x_hat = departition(x_hat, partition_size, n_partitions, n_partitions) x_hat = departition(x_hat, partition_size, n_partitions, n_partitions)
assert torch.allclose(x, x_hat) torch.testing.assert_close(x, x_hat)
def test_maxvit_grid_partition(self): def test_maxvit_grid_partition(self):
input_shape = (1, 3, 224, 224) input_shape = (1, 3, 224, 224)
...@@ -39,7 +39,7 @@ class MaxvitTester(unittest.TestCase): ...@@ -39,7 +39,7 @@ class MaxvitTester(unittest.TestCase):
x_hat = post_swap(x_hat) x_hat = post_swap(x_hat)
x_hat = departition(x_hat, n_partitions, partition_size, partition_size) x_hat = departition(x_hat, n_partitions, partition_size, partition_size)
assert torch.allclose(x, x_hat) torch.testing.assert_close(x, x_hat)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -630,7 +630,7 @@ class TestNMS: ...@@ -630,7 +630,7 @@ class TestNMS:
boxes, scores = self._create_tensors_with_iou(1000, iou) boxes, scores = self._create_tensors_with_iou(1000, iou)
keep_ref = self._reference_nms(boxes, scores, iou) keep_ref = self._reference_nms(boxes, scores, iou)
keep = ops.nms(boxes, scores, iou) keep = ops.nms(boxes, scores, iou)
assert torch.allclose(keep, keep_ref), err_msg.format(iou) torch.testing.assert_close(keep, keep_ref, msg=err_msg.format(iou))
def test_nms_input_errors(self): def test_nms_input_errors(self):
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
...@@ -661,7 +661,7 @@ class TestNMS: ...@@ -661,7 +661,7 @@ class TestNMS:
keep = ops.nms(boxes, scores, iou) keep = ops.nms(boxes, scores, iou)
qkeep = ops.nms(qboxes, qscores, iou) qkeep = ops.nms(qboxes, qscores, iou)
assert torch.allclose(qkeep, keep), err_msg.format(iou) torch.testing.assert_close(qkeep, keep, msg=err_msg.format(iou))
@needs_cuda @needs_cuda
@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
...@@ -1237,7 +1237,7 @@ class TestIouBase: ...@@ -1237,7 +1237,7 @@ class TestIouBase:
boxes2 = gen_box(7) boxes2 = gen_box(7)
a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn) a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn)
b = target_fn(boxes1, boxes2) b = target_fn(boxes1, boxes2)
assert torch.allclose(a, b) torch.testing.assert_close(a, b)
class TestBoxIou(TestIouBase): class TestBoxIou(TestIouBase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment