Unverified Commit 42baa58f authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`SAM`] Fixes pipeline and adds a dummy pipeline test (#23684)

* add a dummy pipeline test

* change test name
parent 71a5ed34
......@@ -934,7 +934,7 @@ def _generate_crop_boxes(
cropped_images, point_grid_per_crop = _generate_crop_images(
crop_boxes, image, points_grid, layer_idxs, target_size, original_size
)
crop_boxes = np.array(crop_boxes)
crop_boxes = crop_boxes.astype(np.float32)
points_per_crop = np.array([point_grid_per_crop])
points_per_crop = np.transpose(points_per_crop, axes=(0, 2, 1, 3))
......
......@@ -20,7 +20,7 @@ import unittest
import requests
from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig, pipeline
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available
......@@ -751,3 +751,9 @@ class SamModelIntegrationTest(unittest.TestCase):
iou_scores = outputs.iou_scores.cpu()
self.assertTrue(iou_scores.shape == (1, 3, 3))
torch.testing.assert_allclose(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4)
def test_dummy_pipeline_generation(self):
generator = pipeline("mask-generation", model="facebook/sam-vit-base", device=torch_device)
raw_image = prepare_image()
_ = generator(raw_image, points_per_batch=64)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment