"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "513fa30a636642ccc1d93f3e6a48d612d08dbce8"
Unverified Commit abf691aa authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix PyTorch SAM tests (#23682)



fix
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent b687af0b
...@@ -476,7 +476,7 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -476,7 +476,7 @@ class SamModelIntegrationTest(unittest.TestCase):
scores = outputs.iou_scores.squeeze() scores = outputs.iou_scores.squeeze()
masks = outputs.pred_masks[0, 0, 0, 0, :3] masks = outputs.pred_masks[0, 0, 0, 0, :3]
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.4515), atol=2e-4)) self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.4515), atol=2e-4))
self.assertTrue(torch.allclose(masks, torch.tensor([-4.1807, -3.4949, -3.4483]).to(torch_device), atol=2e-4)) self.assertTrue(torch.allclose(masks, torch.tensor([-4.1800, -3.4948, -3.4481]).to(torch_device), atol=2e-4))
def test_inference_mask_generation_one_point_one_bb(self): def test_inference_mask_generation_one_point_one_bb(self):
model = SamModel.from_pretrained("facebook/sam-vit-base") model = SamModel.from_pretrained("facebook/sam-vit-base")
...@@ -499,7 +499,7 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -499,7 +499,7 @@ class SamModelIntegrationTest(unittest.TestCase):
masks = outputs.pred_masks[0, 0, 0, 0, :3] masks = outputs.pred_masks[0, 0, 0, 0, :3]
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9566), atol=2e-4)) self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9566), atol=2e-4))
self.assertTrue( self.assertTrue(
torch.allclose(masks, torch.tensor([-12.7657, -12.3683, -12.5985]).to(torch_device), atol=2e-4) torch.allclose(masks, torch.tensor([-12.7729, -12.3665, -12.6061]).to(torch_device), atol=2e-4)
) )
def test_inference_mask_generation_batched_points_batched_images(self): def test_inference_mask_generation_batched_points_batched_images(self):
...@@ -540,7 +540,7 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -540,7 +540,7 @@ class SamModelIntegrationTest(unittest.TestCase):
], ],
] ]
) )
EXPECTED_MASKS = torch.tensor([-2.8552, -2.7990, -2.9612]) EXPECTED_MASKS = torch.tensor([-2.8550, -2.7988, -2.9625])
self.assertTrue(torch.allclose(scores, EXPECTED_SCORES, atol=1e-3)) self.assertTrue(torch.allclose(scores, EXPECTED_SCORES, atol=1e-3))
self.assertTrue(torch.allclose(masks, EXPECTED_MASKS, atol=1e-3)) self.assertTrue(torch.allclose(masks, EXPECTED_MASKS, atol=1e-3))
...@@ -568,7 +568,7 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -568,7 +568,7 @@ class SamModelIntegrationTest(unittest.TestCase):
outputs = model(**inputs) outputs = model(**inputs)
scores = outputs.iou_scores.squeeze() scores = outputs.iou_scores.squeeze()
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.7892), atol=1e-4)) self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.7894), atol=1e-4))
def test_inference_mask_generation_one_point(self): def test_inference_mask_generation_one_point(self):
model = SamModel.from_pretrained("facebook/sam-vit-base") model = SamModel.from_pretrained("facebook/sam-vit-base")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment