Unverified commit 26a06814, authored by Matt and committed by GitHub
Browse files

Fix SAM tests and use smaller checkpoints (#23656)

* Fix SAM tests and use smaller checkpoints

* Override test_model_from_pretrained to use sam-vit-base as well

* make fixup
parent 6f72e71f
...@@ -436,8 +436,9 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): ...@@ -436,8 +436,9 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def test_hidden_states_output(self): def test_hidden_states_output(self):
pass pass
def test_pt_tf_model_equivalence(self, allow_missing_keys=True, tol=5e-4): def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
super().test_pt_tf_model_equivalence(allow_missing_keys=True, tol=tol) # Use a slightly higher default tol to make the tests non-flaky
super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol=tol, name=name, attributes=attributes)
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
...@@ -461,8 +462,8 @@ def prepare_dog_img(): ...@@ -461,8 +462,8 @@ def prepare_dog_img():
@slow @slow
class SamModelIntegrationTest(unittest.TestCase): class SamModelIntegrationTest(unittest.TestCase):
def test_inference_mask_generation_no_point(self): def test_inference_mask_generation_no_point(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge") model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
...@@ -474,13 +475,12 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -474,13 +475,12 @@ class SamModelIntegrationTest(unittest.TestCase):
outputs = model(**inputs) outputs = model(**inputs)
scores = outputs.iou_scores.squeeze() scores = outputs.iou_scores.squeeze()
masks = outputs.pred_masks[0, 0, 0, 0, :3] masks = outputs.pred_masks[0, 0, 0, 0, :3]
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.4515), atol=2e-4))
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=2e-4)) self.assertTrue(torch.allclose(masks, torch.tensor([-4.1807, -3.4949, -3.4483]).to(torch_device), atol=2e-4))
self.assertTrue(torch.allclose(masks, torch.tensor([-6.6381, -6.0734, -7.5308]).to(torch_device), atol=2e-4))
def test_inference_mask_generation_one_point_one_bb(self): def test_inference_mask_generation_one_point_one_bb(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge") model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
...@@ -497,15 +497,14 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -497,15 +497,14 @@ class SamModelIntegrationTest(unittest.TestCase):
outputs = model(**inputs) outputs = model(**inputs)
scores = outputs.iou_scores.squeeze() scores = outputs.iou_scores.squeeze()
masks = outputs.pred_masks[0, 0, 0, 0, :3] masks = outputs.pred_masks[0, 0, 0, 0, :3]
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9566), atol=2e-4))
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9935), atol=2e-4))
self.assertTrue( self.assertTrue(
torch.allclose(masks, torch.tensor([-21.5465, -23.1122, -22.3331]).to(torch_device), atol=2e-4) torch.allclose(masks, torch.tensor([-12.7657, -12.3683, -12.5985]).to(torch_device), atol=2e-4)
) )
def test_inference_mask_generation_batched_points_batched_images(self): def test_inference_mask_generation_batched_points_batched_images(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge") model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
...@@ -528,26 +527,26 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -528,26 +527,26 @@ class SamModelIntegrationTest(unittest.TestCase):
EXPECTED_SCORES = torch.tensor( EXPECTED_SCORES = torch.tensor(
[ [
[ [
[0.9673, 0.9441, 0.9084], [0.6765, 0.9379, 0.8803],
[0.9673, 0.9441, 0.9084], [0.6765, 0.9379, 0.8803],
[0.9673, 0.9441, 0.9084], [0.6765, 0.9379, 0.8803],
[0.9673, 0.9441, 0.9084], [0.6765, 0.9379, 0.8803],
], ],
[ [
[0.8405, 0.6292, 0.3840], [0.3317, 0.7264, 0.7646],
[0.9673, 0.9441, 0.9084], [0.6765, 0.9379, 0.8803],
[0.9673, 0.9441, 0.9084], [0.6765, 0.9379, 0.8803],
[0.9673, 0.9441, 0.9084], [0.6765, 0.9379, 0.8803],
], ],
] ]
) )
EXPECTED_MASKS = torch.tensor([-26.5424, -34.0901, -30.6406]) EXPECTED_MASKS = torch.tensor([-2.8552, -2.7990, -2.9612])
self.assertTrue(torch.allclose(scores, EXPECTED_SCORES, atol=1e-3)) self.assertTrue(torch.allclose(scores, EXPECTED_SCORES, atol=1e-3))
self.assertTrue(torch.allclose(masks, EXPECTED_MASKS, atol=1e-3)) self.assertTrue(torch.allclose(masks, EXPECTED_MASKS, atol=1e-3))
def test_inference_mask_generation_one_point_one_bb_zero(self): def test_inference_mask_generation_one_point_one_bb_zero(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge") model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
...@@ -569,11 +568,11 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -569,11 +568,11 @@ class SamModelIntegrationTest(unittest.TestCase):
outputs = model(**inputs) outputs = model(**inputs)
scores = outputs.iou_scores.squeeze() scores = outputs.iou_scores.squeeze()
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9689), atol=1e-4)) self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.7892), atol=1e-4))
def test_inference_mask_generation_one_point(self): def test_inference_mask_generation_one_point(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge") model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
...@@ -590,8 +589,7 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -590,8 +589,7 @@ class SamModelIntegrationTest(unittest.TestCase):
with torch.no_grad(): with torch.no_grad():
outputs = model(**inputs) outputs = model(**inputs)
scores = outputs.iou_scores.squeeze() scores = outputs.iou_scores.squeeze()
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9675), atol=1e-4))
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9712), atol=1e-4))
# With no label # With no label
input_points = [[[400, 650]]] input_points = [[[400, 650]]]
...@@ -601,12 +599,11 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -601,12 +599,11 @@ class SamModelIntegrationTest(unittest.TestCase):
with torch.no_grad(): with torch.no_grad():
outputs = model(**inputs) outputs = model(**inputs)
scores = outputs.iou_scores.squeeze() scores = outputs.iou_scores.squeeze()
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9675), atol=1e-4))
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9712), atol=1e-4))
def test_inference_mask_generation_two_points(self): def test_inference_mask_generation_two_points(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge") model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
...@@ -623,8 +620,7 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -623,8 +620,7 @@ class SamModelIntegrationTest(unittest.TestCase):
with torch.no_grad(): with torch.no_grad():
outputs = model(**inputs) outputs = model(**inputs)
scores = outputs.iou_scores.squeeze() scores = outputs.iou_scores.squeeze()
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9762), atol=1e-4))
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9936), atol=1e-4))
# no labels # no labels
inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt").to(torch_device) inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt").to(torch_device)
...@@ -633,11 +629,11 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -633,11 +629,11 @@ class SamModelIntegrationTest(unittest.TestCase):
outputs = model(**inputs) outputs = model(**inputs)
scores = outputs.iou_scores.squeeze() scores = outputs.iou_scores.squeeze()
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9936), atol=1e-4)) self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9762), atol=1e-4))
def test_inference_mask_generation_two_points_batched(self): def test_inference_mask_generation_two_points_batched(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge") model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
...@@ -654,13 +650,12 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -654,13 +650,12 @@ class SamModelIntegrationTest(unittest.TestCase):
with torch.no_grad(): with torch.no_grad():
outputs = model(**inputs) outputs = model(**inputs)
scores = outputs.iou_scores.squeeze() scores = outputs.iou_scores.squeeze()
self.assertTrue(torch.allclose(scores[0][-1], torch.tensor(0.9762), atol=1e-4))
self.assertTrue(torch.allclose(scores[0][-1], torch.tensor(0.9936), atol=1e-4)) self.assertTrue(torch.allclose(scores[1][-1], torch.tensor(0.9637), atol=1e-4))
self.assertTrue(torch.allclose(scores[1][-1], torch.tensor(0.9716), atol=1e-4))
def test_inference_mask_generation_one_box(self): def test_inference_mask_generation_one_box(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge") model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
...@@ -674,12 +669,11 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -674,12 +669,11 @@ class SamModelIntegrationTest(unittest.TestCase):
with torch.no_grad(): with torch.no_grad():
outputs = model(**inputs) outputs = model(**inputs)
scores = outputs.iou_scores.squeeze() scores = outputs.iou_scores.squeeze()
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.7937), atol=1e-4))
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.8686), atol=1e-4))
def test_inference_mask_generation_batched_image_one_point(self): def test_inference_mask_generation_batched_image_one_point(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge") model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
...@@ -707,8 +701,8 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -707,8 +701,8 @@ class SamModelIntegrationTest(unittest.TestCase):
self.assertTrue(torch.allclose(scores_batched[1, :], scores_single, atol=1e-4)) self.assertTrue(torch.allclose(scores_batched[1, :], scores_single, atol=1e-4))
def test_inference_mask_generation_two_points_point_batch(self): def test_inference_mask_generation_two_points_point_batch(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge") model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
...@@ -729,12 +723,12 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -729,12 +723,12 @@ class SamModelIntegrationTest(unittest.TestCase):
iou_scores = outputs.iou_scores.cpu() iou_scores = outputs.iou_scores.cpu()
self.assertTrue(iou_scores.shape == (1, 2, 3)) self.assertTrue(iou_scores.shape == (1, 2, 3))
torch.testing.assert_allclose( torch.testing.assert_allclose(
iou_scores, torch.tensor([[[0.9848, 0.9788, 0.9713], [0.9211, 0.9128, 0.7427]]]), atol=1e-4, rtol=1e-4 iou_scores, torch.tensor([[[0.9105, 0.9825, 0.9675], [0.7646, 0.7943, 0.7774]]]), atol=1e-4, rtol=1e-4
) )
def test_inference_mask_generation_three_boxes_point_batch(self): def test_inference_mask_generation_three_boxes_point_batch(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge") model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
...@@ -743,7 +737,9 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -743,7 +737,9 @@ class SamModelIntegrationTest(unittest.TestCase):
# fmt: off # fmt: off
input_boxes = torch.Tensor([[[620, 900, 1000, 1255]], [[75, 275, 1725, 850]], [[75, 275, 1725, 850]]]).cpu() input_boxes = torch.Tensor([[[620, 900, 1000, 1255]], [[75, 275, 1725, 850]], [[75, 275, 1725, 850]]]).cpu()
EXPECTED_IOU = torch.tensor([[[1.0071, 1.0032, 0.9946], [0.4962, 0.8770, 0.8686], [0.4962, 0.8770, 0.8686]]]) EXPECTED_IOU = torch.tensor([[[0.9773, 0.9881, 0.9522],
[0.5996, 0.7661, 0.7937],
[0.5996, 0.7661, 0.7937]]])
# fmt: on # fmt: on
input_boxes = input_boxes.unsqueeze(0) input_boxes = input_boxes.unsqueeze(0)
......
...@@ -34,7 +34,6 @@ if is_tf_available(): ...@@ -34,7 +34,6 @@ if is_tf_available():
import tensorflow as tf import tensorflow as tf
from transformers import SamProcessor, TFSamModel from transformers import SamProcessor, TFSamModel
from transformers.models.sam.modeling_tf_sam import TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available(): if is_vision_available():
from PIL import Image from PIL import Image
...@@ -400,9 +399,8 @@ class TFSamModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase) ...@@ -400,9 +399,8 @@ class TFSamModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: model = TFSamModel.from_pretrained("facebook/sam-vit-base") # sam-vit-huge blows out our memory
model = TFSamModel.from_pretrained(model_name) self.assertIsNotNone(model)
self.assertIsNotNone(model)
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=5e-4, name="outputs", attributes=None): def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=5e-4, name="outputs", attributes=None):
super().check_pt_tf_outputs( super().check_pt_tf_outputs(
...@@ -430,8 +428,8 @@ def prepare_dog_img(): ...@@ -430,8 +428,8 @@ def prepare_dog_img():
@slow @slow
class SamModelIntegrationTest(unittest.TestCase): class SamModelIntegrationTest(unittest.TestCase):
def test_inference_mask_generation_no_point(self): def test_inference_mask_generation_no_point(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge") model = TFSamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
raw_image = prepare_image() raw_image = prepare_image()
inputs = processor(images=raw_image, return_tensors="tf") inputs = processor(images=raw_image, return_tensors="tf")
...@@ -439,13 +437,12 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -439,13 +437,12 @@ class SamModelIntegrationTest(unittest.TestCase):
outputs = model(**inputs) outputs = model(**inputs)
scores = tf.squeeze(outputs.iou_scores) scores = tf.squeeze(outputs.iou_scores)
masks = outputs.pred_masks[0, 0, 0, 0, :3] masks = outputs.pred_masks[0, 0, 0, 0, :3]
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.4515), atol=2e-4))
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.5798), atol=2e-4)) self.assertTrue(np.allclose(masks.numpy(), np.array([-4.1807, -3.4949, -3.4483]), atol=1e-2))
self.assertTrue(np.allclose(masks.numpy(), np.array([-6.6381, -6.0734, -7.5308]), atol=1e-2))
def test_inference_mask_generation_one_point_one_bb(self): def test_inference_mask_generation_one_point_one_bb(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge") model = TFSamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
raw_image = prepare_image() raw_image = prepare_image()
input_boxes = [[[650, 900, 1000, 1250]]] input_boxes = [[[650, 900, 1000, 1250]]]
...@@ -457,12 +454,12 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -457,12 +454,12 @@ class SamModelIntegrationTest(unittest.TestCase):
scores = tf.squeeze(outputs.iou_scores) scores = tf.squeeze(outputs.iou_scores)
masks = outputs.pred_masks[0, 0, 0, 0, :3] masks = outputs.pred_masks[0, 0, 0, 0, :3]
self.assertTrue(np.allclose(scores[-1], np.array(0.9935), atol=2e-4)) self.assertTrue(np.allclose(scores[-1], np.array(0.9566), atol=2e-4))
self.assertTrue(np.allclose(masks.numpy(), np.array([-21.5465, -23.1122, -22.3331]), atol=2e-2)) self.assertTrue(np.allclose(masks.numpy(), np.array([-12.7657, -12.3683, -12.5985]), atol=2e-2))
def test_inference_mask_generation_batched_points_batched_images(self): def test_inference_mask_generation_batched_points_batched_images(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge") model = TFSamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
raw_image = prepare_image() raw_image = prepare_image()
input_points = [ input_points = [
...@@ -479,26 +476,26 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -479,26 +476,26 @@ class SamModelIntegrationTest(unittest.TestCase):
EXPECTED_SCORES = np.array( EXPECTED_SCORES = np.array(
[ [
[ [
[0.9673, 0.9441, 0.9084], [0.6765, 0.9379, 0.8803],
[0.9673, 0.9441, 0.9084], [0.6765, 0.9379, 0.8803],
[0.9673, 0.9441, 0.9084], [0.6765, 0.9379, 0.8803],
[0.9673, 0.9441, 0.9084], [0.6765, 0.9379, 0.8803],
], ],
[ [
[0.8405, 0.6292, 0.3840], [0.3317, 0.7264, 0.7646],
[0.9673, 0.9441, 0.9084], [0.6765, 0.9379, 0.8803],
[0.9673, 0.9441, 0.9084], [0.6765, 0.9379, 0.8803],
[0.9673, 0.9441, 0.9084], [0.6765, 0.9379, 0.8803],
], ],
] ]
) )
EXPECTED_MASKS = np.array([-26.5424, -34.0901, -30.6406]) EXPECTED_MASKS = np.array([-2.8552, -2.7990, -2.9612])
self.assertTrue(np.allclose(scores.numpy(), EXPECTED_SCORES, atol=1e-3)) self.assertTrue(np.allclose(scores.numpy(), EXPECTED_SCORES, atol=1e-3))
self.assertTrue(np.allclose(masks.numpy(), EXPECTED_MASKS, atol=3e-2)) self.assertTrue(np.allclose(masks.numpy(), EXPECTED_MASKS, atol=3e-2))
def test_inference_mask_generation_one_point_one_bb_zero(self): def test_inference_mask_generation_one_point_one_bb_zero(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge") model = TFSamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
raw_image = prepare_image() raw_image = prepare_image()
input_boxes = [[[620, 900, 1000, 1255]]] input_boxes = [[[620, 900, 1000, 1255]]]
...@@ -515,12 +512,11 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -515,12 +512,11 @@ class SamModelIntegrationTest(unittest.TestCase):
outputs = model(**inputs) outputs = model(**inputs)
scores = tf.squeeze(outputs.iou_scores) scores = tf.squeeze(outputs.iou_scores)
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.7894), atol=1e-4))
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9689), atol=1e-4))
def test_inference_mask_generation_one_point(self): def test_inference_mask_generation_one_point(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge") model = TFSamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
raw_image = prepare_image() raw_image = prepare_image()
...@@ -532,7 +528,7 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -532,7 +528,7 @@ class SamModelIntegrationTest(unittest.TestCase):
outputs = model(**inputs) outputs = model(**inputs)
scores = tf.squeeze(outputs.iou_scores) scores = tf.squeeze(outputs.iou_scores)
self.assertTrue(np.allclose(scores[-1], np.array(0.9712), atol=1e-4)) self.assertTrue(np.allclose(scores[-1], np.array(0.9675), atol=1e-4))
# With no label # With no label
input_points = [[[400, 650]]] input_points = [[[400, 650]]]
...@@ -542,11 +538,11 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -542,11 +538,11 @@ class SamModelIntegrationTest(unittest.TestCase):
outputs = model(**inputs) outputs = model(**inputs)
scores = tf.squeeze(outputs.iou_scores) scores = tf.squeeze(outputs.iou_scores)
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9712), atol=1e-4)) self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9675), atol=1e-4))
def test_inference_mask_generation_two_points(self): def test_inference_mask_generation_two_points(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge") model = TFSamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
raw_image = prepare_image() raw_image = prepare_image()
input_points = [[[400, 650], [800, 650]]] input_points = [[[400, 650], [800, 650]]]
...@@ -557,7 +553,7 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -557,7 +553,7 @@ class SamModelIntegrationTest(unittest.TestCase):
outputs = model(**inputs) outputs = model(**inputs)
scores = tf.squeeze(outputs.iou_scores) scores = tf.squeeze(outputs.iou_scores)
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9936), atol=1e-4)) self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9762), atol=1e-4))
# no labels # no labels
inputs = processor(images=raw_image, input_points=input_points, return_tensors="tf") inputs = processor(images=raw_image, input_points=input_points, return_tensors="tf")
...@@ -565,11 +561,11 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -565,11 +561,11 @@ class SamModelIntegrationTest(unittest.TestCase):
outputs = model(**inputs) outputs = model(**inputs)
scores = tf.squeeze(outputs.iou_scores) scores = tf.squeeze(outputs.iou_scores)
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9936), atol=1e-4)) self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9762), atol=1e-4))
def test_inference_mask_generation_two_points_batched(self): def test_inference_mask_generation_two_points_batched(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge") model = TFSamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
raw_image = prepare_image() raw_image = prepare_image()
...@@ -583,12 +579,12 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -583,12 +579,12 @@ class SamModelIntegrationTest(unittest.TestCase):
outputs = model(**inputs) outputs = model(**inputs)
scores = tf.squeeze(outputs.iou_scores) scores = tf.squeeze(outputs.iou_scores)
self.assertTrue(np.allclose(scores[0][-1].numpy(), np.array(0.9936), atol=1e-4)) self.assertTrue(np.allclose(scores[0][-1].numpy(), np.array(0.9762), atol=1e-4))
self.assertTrue(np.allclose(scores[1][-1], np.array(0.9716), atol=1e-4)) self.assertTrue(np.allclose(scores[1][-1], np.array(0.9637), atol=1e-4))
def test_inference_mask_generation_one_box(self): def test_inference_mask_generation_one_box(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge") model = TFSamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
raw_image = prepare_image() raw_image = prepare_image()
...@@ -599,11 +595,11 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -599,11 +595,11 @@ class SamModelIntegrationTest(unittest.TestCase):
outputs = model(**inputs) outputs = model(**inputs)
scores = tf.squeeze(outputs.iou_scores) scores = tf.squeeze(outputs.iou_scores)
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.8686), atol=1e-4)) self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.7937), atol=1e-4))
def test_inference_mask_generation_batched_image_one_point(self): def test_inference_mask_generation_batched_image_one_point(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge") model = TFSamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
raw_image = prepare_image() raw_image = prepare_image()
raw_dog_image = prepare_dog_img() raw_dog_image = prepare_dog_img()
...@@ -624,8 +620,8 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -624,8 +620,8 @@ class SamModelIntegrationTest(unittest.TestCase):
self.assertTrue(np.allclose(scores_batched[1, :].numpy(), scores_single.numpy(), atol=1e-4)) self.assertTrue(np.allclose(scores_batched[1, :].numpy(), scores_single.numpy(), atol=1e-4))
def test_inference_mask_generation_two_points_point_batch(self): def test_inference_mask_generation_two_points_point_batch(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge") model = TFSamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
raw_image = prepare_image() raw_image = prepare_image()
...@@ -644,21 +640,23 @@ class SamModelIntegrationTest(unittest.TestCase): ...@@ -644,21 +640,23 @@ class SamModelIntegrationTest(unittest.TestCase):
self.assertTrue( self.assertTrue(
np.allclose( np.allclose(
iou_scores.numpy(), iou_scores.numpy(),
np.array([[[0.9848, 0.9788, 0.9713], [0.9211, 0.9128, 0.7427]]]), np.array([[[0.9105, 0.9825, 0.9675], [0.7646, 0.7943, 0.7774]]]),
atol=1e-4, atol=1e-4,
rtol=1e-4, rtol=1e-4,
) )
) )
def test_inference_mask_generation_three_boxes_point_batch(self): def test_inference_mask_generation_three_boxes_point_batch(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge") model = TFSamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
raw_image = prepare_image() raw_image = prepare_image()
# fmt: off # fmt: off
input_boxes = tf.convert_to_tensor([[[620, 900, 1000, 1255]], [[75, 275, 1725, 850]], [[75, 275, 1725, 850]]]) input_boxes = tf.convert_to_tensor([[[620, 900, 1000, 1255]], [[75, 275, 1725, 850]], [[75, 275, 1725, 850]]])
EXPECTED_IOU = np.array([[[1.0071, 1.0032, 0.9946], [0.4962, 0.8770, 0.8686], [0.4962, 0.8770, 0.8686]]]) EXPECTED_IOU = np.array([[[0.9773, 0.9881, 0.9522],
[0.5996, 0.7661, 0.7937],
[0.5996, 0.7661, 0.7937]]])
# fmt: on # fmt: on
input_boxes = tf.expand_dims(input_boxes, 0) input_boxes = tf.expand_dims(input_boxes, 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment