Unverified Commit b03be78a authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix `test_inference_instance_segmentation_head` (#17872)


Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 494aac65
......@@ -387,9 +387,12 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
self.assertEqual(
masks_queries_logits.shape, (1, model.config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4)
)
expected_slice = torch.tensor(
[[-1.3738, -1.7725, -1.9365], [-1.5978, -1.9869, -2.1524], [-1.5796, -1.9271, -2.0940]]
).to(torch_device)
expected_slice = [
[-1.3737124, -1.7724937, -1.9364233],
[-1.5977281, -1.9867939, -2.1523695],
[-1.5795398, -1.9269832, -2.093942],
]
expected_slice = torch.tensor(expected_slice).to(torch_device)
self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE))
# class_queries_logits
class_queries_logits = outputs.class_queries_logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment