test_modeling_sam.py 28.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch SAM model. """


Yih-Dar's avatar
Yih-Dar committed
18
import gc
19
20
21
22
23
import inspect
import unittest

import requests

24
from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig, pipeline
25
from transformers.testing_utils import backend_empty_cache, require_torch, slow, torch_device
26
27
28
29
from transformers.utils import is_torch_available, is_vision_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor
30
from ...test_pipeline_mixin import PipelineTesterMixin
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286


# Torch-dependent names are imported only when torch is installed, so this module can still be
# collected without torch; the classes below that use them are gated via `is_torch_available()`
# or `@require_torch`.
if is_torch_available():
    import torch
    from torch import nn

    from transformers import SamModel, SamProcessor
    # Checkpoint names iterated by `SamModelTest.test_model_from_pretrained`.
    from transformers.models.sam.modeling_sam import SAM_PRETRAINED_MODEL_ARCHIVE_LIST


# PIL is needed only by the image-download helpers used in the integration tests.
if is_vision_available():
    from PIL import Image


class SamPromptEncoderTester:
    """Build a small `SamPromptEncoderConfig` and dummy point inputs for the SAM unit tests."""

    def __init__(
        self,
        hidden_size=32,
        input_image_size=24,
        patch_size=2,
        mask_input_channels=4,
        num_point_embeddings=4,
        hidden_act="gelu",
        batch_size=2,
    ):
        self.hidden_size = hidden_size
        self.input_image_size = input_image_size
        self.patch_size = patch_size
        self.mask_input_channels = mask_input_channels
        self.num_point_embeddings = num_point_embeddings
        self.hidden_act = hidden_act
        # Read by `prepare_config_and_inputs`; without it that method raised AttributeError.
        # The default matches `SamModelTester`'s default batch size.
        self.batch_size = batch_size

    def get_config(self):
        """Return a `SamPromptEncoderConfig` built from this tester's settings."""
        return SamPromptEncoderConfig(
            image_size=self.input_image_size,
            patch_size=self.patch_size,
            mask_input_channels=self.mask_input_channels,
            hidden_size=self.hidden_size,
            num_point_embeddings=self.num_point_embeddings,
            hidden_act=self.hidden_act,
        )

    def prepare_config_and_inputs(self):
        """Return `(config, dummy_points)` with random points of shape `(batch_size, 3, 2)`."""
        # NOTE(review): presumably 3 points per example, each an (x, y) pair — confirm.
        dummy_points = floats_tensor([self.batch_size, 3, 2])
        config = self.get_config()

        return config, dummy_points


class SamMaskDecoderTester:
    """Build a small `SamMaskDecoderConfig` and dummy decoder inputs for the SAM unit tests."""

    def __init__(
        self,
        hidden_size=32,
        hidden_act="relu",
        mlp_dim=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        attention_downsample_rate=2,
        num_multimask_outputs=3,
        iou_head_depth=3,
        iou_head_hidden_dim=32,
        layer_norm_eps=1e-6,
        batch_size=2,
    ):
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.mlp_dim = mlp_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.attention_downsample_rate = attention_downsample_rate
        self.num_multimask_outputs = num_multimask_outputs
        self.iou_head_depth = iou_head_depth
        self.iou_head_hidden_dim = iou_head_hidden_dim
        self.layer_norm_eps = layer_norm_eps
        # Read by `prepare_config_and_inputs`; without it that method raised AttributeError.
        # The default matches `SamModelTester`'s default batch size.
        self.batch_size = batch_size

    def get_config(self):
        """Return a `SamMaskDecoderConfig` built from this tester's settings."""
        return SamMaskDecoderConfig(
            hidden_size=self.hidden_size,
            hidden_act=self.hidden_act,
            mlp_dim=self.mlp_dim,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            attention_downsample_rate=self.attention_downsample_rate,
            num_multimask_outputs=self.num_multimask_outputs,
            iou_head_depth=self.iou_head_depth,
            iou_head_hidden_dim=self.iou_head_hidden_dim,
            layer_norm_eps=self.layer_norm_eps,
        )

    def prepare_config_and_inputs(self):
        """Return `(config, inputs)`; `inputs["image_embedding"]` has shape `(batch_size, hidden_size)`."""
        config = self.get_config()

        dummy_inputs = {
            "image_embedding": floats_tensor([self.batch_size, self.hidden_size]),
        }

        return config, dummy_inputs


class SamModelTester:
    """Build a tiny `SamConfig` plus random pixel values, and check `SamModel` output shapes.

    Vision-encoder settings come from the constructor arguments. The prompt-encoder and
    mask-decoder sub-configs come from default `SamPromptEncoderTester` /
    `SamMaskDecoderTester` instances. `parent` is the `unittest.TestCase` whose assertion
    methods the `create_and_check_*` helpers use.
    """

    def __init__(
        self,
        parent,
        hidden_size=36,
        intermediate_size=72,
        projection_dim=62,
        output_channels=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_channels=3,
        image_size=24,
        patch_size=2,
        hidden_act="gelu",
        layer_norm_eps=1e-06,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        qkv_bias=True,
        mlp_ratio=4.0,
        use_abs_pos=True,
        use_rel_pos=True,
        rel_pos_zero_init=False,
        window_size=14,
        global_attn_indexes=None,
        num_pos_feats=16,
        mlp_dim=None,
        batch_size=2,
    ):
        # A list literal as the default would be a single object shared by every tester;
        # build a fresh list with the same default values per instance instead.
        if global_attn_indexes is None:
            global_attn_indexes = [2, 5, 8, 11]

        self.parent = parent
        self.image_size = image_size
        self.patch_size = patch_size
        self.output_channels = output_channels
        self.num_channels = num_channels
        self.hidden_size = hidden_size
        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps
        self.qkv_bias = qkv_bias
        self.mlp_ratio = mlp_ratio
        self.use_abs_pos = use_abs_pos
        self.use_rel_pos = use_rel_pos
        self.rel_pos_zero_init = rel_pos_zero_init
        self.window_size = window_size
        self.global_attn_indexes = global_attn_indexes
        self.num_pos_feats = num_pos_feats
        self.mlp_dim = mlp_dim
        self.batch_size = batch_size

        # Follows the ViT convention (number of patches + 1 for a [CLS] token). SAM's vision
        # hidden states are a (grid, grid) map with no extra token, as checked in
        # `create_and_check_get_image_hidden_states`.
        num_patches = (image_size // patch_size) ** 2
        self.seq_length = num_patches + 1

        self.prompt_encoder_tester = SamPromptEncoderTester()
        self.mask_decoder_tester = SamMaskDecoderTester()

    def _feature_grid_size(self):
        """Side length of the vision encoder's spatial feature map (patches per image side)."""
        return self.image_size // self.patch_size

    def _create_model(self, config):
        """Instantiate `SamModel` from `config`, move it to `torch_device` and put it in eval mode."""
        model = SamModel(config=config)
        model.to(torch_device)
        model.eval()
        return model

    def prepare_config_and_inputs(self):
        """Return `(config, pixel_values)` with random pixels of shape (batch, channels, H, W)."""
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
        config = self.get_config()

        return config, pixel_values

    def get_config(self):
        """Assemble a full `SamConfig` from the vision settings and the two sub-testers."""
        vision_config = SamVisionConfig(
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            projection_dim=self.projection_dim,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            initializer_range=self.initializer_range,
            initializer_factor=self.initializer_factor,
            output_channels=self.output_channels,
            qkv_bias=self.qkv_bias,
            mlp_ratio=self.mlp_ratio,
            use_abs_pos=self.use_abs_pos,
            use_rel_pos=self.use_rel_pos,
            rel_pos_zero_init=self.rel_pos_zero_init,
            window_size=self.window_size,
            global_attn_indexes=self.global_attn_indexes,
            num_pos_feats=self.num_pos_feats,
            mlp_dim=self.mlp_dim,
        )

        prompt_encoder_config = self.prompt_encoder_tester.get_config()

        mask_decoder_config = self.mask_decoder_tester.get_config()

        return SamConfig(
            vision_config=vision_config,
            prompt_encoder_config=prompt_encoder_config,
            mask_decoder_config=mask_decoder_config,
        )

    def create_and_check_model(self, config, pixel_values):
        """Run a full forward pass and check the leading dims of `iou_scores` and `pred_masks`."""
        model = self._create_model(config)
        with torch.no_grad():
            result = model(pixel_values)
        # NOTE(review): the trailing 3 presumably equals the mask decoder's num_multimask_outputs — confirm.
        self.parent.assertEqual(result.iou_scores.shape, (self.batch_size, 1, 3))
        self.parent.assertEqual(result.pred_masks.shape[:3], (self.batch_size, 1, 3))

    def create_and_check_get_image_features(self, config, pixel_values):
        """Check that one image embedding has shape (output_channels, grid, grid)."""
        model = self._create_model(config)
        with torch.no_grad():
            result = model.get_image_embeddings(pixel_values)
        grid_size = self._feature_grid_size()
        self.parent.assertEqual(result[0].shape, (self.output_channels, grid_size, grid_size))

    def create_and_check_get_image_hidden_states(self, config, pixel_values):
        """Check the vision encoder's hidden states for both `return_dict=True` and `False`."""
        model = self._create_model(config)
        grid_size = self._feature_grid_size()
        # after computing the convolutional features
        expected_hidden_states_shape = (self.batch_size, grid_size, grid_size, self.hidden_size)

        # In both output formats the hidden states are at index 1: embeddings + one per layer.
        for return_dict in (True, False):
            with torch.no_grad():
                result = model.vision_encoder(
                    pixel_values,
                    output_hidden_states=True,
                    return_dict=return_dict,
                )
            self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1)
            self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape)

    def prepare_config_and_inputs_for_common(self):
        """Return `(config, inputs_dict)` in the form expected by `ModelTesterMixin`."""
        config, pixel_values = self.prepare_config_and_inputs()
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict


@require_torch
class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as SAM's vision encoder does not use input_ids, inputs_embeds,
    attention_mask and seq_length.
    """

    all_model_classes = (SamModel,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {"feature-extraction": SamModel, "mask-generation": SamModel} if is_torch_available() else {}
    )
    fx_compatible = False
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False
    test_torchscript = False

    # TODO: Fix me @Arthur: `run_batch_test` in `tests/test_pipeline_mixin.py` not working
    def is_pipeline_test_to_skip(
        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
    ):
        # Unconditionally skip every pipeline test for SAM until the TODO above is resolved.
        return True

    def setUp(self):
        self.model_tester = SamModelTester(self)
        # SAM's config is split into three sub-configs, each exercised by its own ConfigTester.
        self.vision_config_tester = ConfigTester(self, config_class=SamVisionConfig, has_text_modality=False)
        self.prompt_encoder_config_tester = ConfigTester(
            self,
            config_class=SamPromptEncoderConfig,
            has_text_modality=False,
            num_attention_heads=12,
            num_hidden_layers=2,
        )
        self.mask_decoder_config_tester = ConfigTester(
            self, config_class=SamMaskDecoderConfig, has_text_modality=False
        )

    def test_config(self):
        """Run the shared config checks on each of SAM's three sub-configs."""
        self.vision_config_tester.run_common_tests()
        self.prompt_encoder_config_tester.run_common_tests()
        self.mask_decoder_config_tester.run_common_tests()

    @unittest.skip(reason="SAM's vision encoder does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    def test_model_common_attributes(self):
        """Input embeddings must be an `nn.Module`; output embeddings, if any, an `nn.Linear`."""
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), nn.Module)
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, nn.Linear))

    def test_forward_signature(self):
        """`pixel_values` must be the first argument of `forward`."""
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_get_image_features(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_get_image_features(*config_and_inputs)

    def test_image_hidden_states(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_get_image_hidden_states(*config_and_inputs)

    def test_attention_outputs(self):
        """Check attention outputs when requested via forward kwargs and via the config."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        # NOTE(review): 196 presumably equals window_size**2 (14 * 14) and 144 the 12x12 feature
        # grid — confirm before changing the tester's sizes.
        expected_vision_attention_shape = (
            self.model_tester.batch_size * self.model_tester.num_attention_heads,
            196,
            196,
        )
        expected_mask_decoder_attention_shape = (self.model_tester.batch_size, 1, 144, 32)

        for model_class in self.all_model_classes:
            # 1) output_attentions passed as a forward kwarg
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            config.return_dict = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            vision_attentions = outputs.vision_attentions
            self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers)

            mask_decoder_attentions = outputs.mask_decoder_attentions
            self.assertEqual(len(mask_decoder_attentions), self.model_tester.mask_decoder_tester.num_hidden_layers)

            # 2) check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            vision_attentions = outputs.vision_attentions
            self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers)

            mask_decoder_attentions = outputs.mask_decoder_attentions
            self.assertEqual(len(mask_decoder_attentions), self.model_tester.mask_decoder_tester.num_hidden_layers)

            self.assertListEqual(
                list(vision_attentions[0].shape[-4:]),
                list(expected_vision_attention_shape),
            )

            self.assertListEqual(
                list(mask_decoder_attentions[0].shape[-4:]),
                list(expected_mask_decoder_attention_shape),
            )

    @unittest.skip(reason="SamModel does not support training")
    def test_training(self):
        pass

    @unittest.skip(reason="SamModel does not support training")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="SamModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip(reason="SamModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_to_base(self):
        pass

    @unittest.skip(reason="SamModel does not support training")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip(reason="Hidden_states is tested in create_and_check_model tests")
    def test_hidden_states_output(self):
        pass

    def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
        # Use a slightly higher default tol to make the tests non-flaky
        super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol=tol, name=name, attributes=attributes)

    @slow
    def test_model_from_pretrained(self):
        """Loading the first published SAM checkpoint must succeed."""
        for model_name in SAM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = SamModel.from_pretrained(model_name)
            self.assertIsNotNone(model)


def _load_rgb_image(img_url, timeout=30):
    """Download the image at `img_url` and return it converted to RGB.

    The request is bounded by `timeout` seconds so an unreachable host cannot hang the run.
    A non-2xx response raises `requests.HTTPError` instead of handing an error page to PIL.
    """
    with requests.get(img_url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        # `.convert("RGB")` makes PIL read the whole stream before the response is closed.
        return Image.open(response.raw).convert("RGB")


def prepare_image():
    """Return the car image used by the SAM integration tests, as RGB."""
    img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
    return _load_rgb_image(img_url)


def prepare_dog_img():
    """Return the dog image used by the batched-image SAM integration test, as RGB."""
    img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dog-sam.png"
    return _load_rgb_image(img_url)


@slow
class SamModelIntegrationTest(unittest.TestCase):
    """End-to-end checks of `facebook/sam-vit-base` against reference IoU scores and mask logits."""

    def tearDown(self):
        super().tearDown()
        # clean-up as much as possible GPU memory occupied by PyTorch
        gc.collect()
        backend_empty_cache(torch_device)

    def _load_model_and_processor(self):
        """Load the `facebook/sam-vit-base` model and processor; the model is on `torch_device` in eval mode."""
        model = SamModel.from_pretrained("facebook/sam-vit-base")
        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
        model.to(torch_device)
        model.eval()
        return model, processor

    def test_inference_mask_generation_no_point(self):
        model, processor = self._load_model_and_processor()

        raw_image = prepare_image()
        inputs = processor(images=raw_image, return_tensors="pt").to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze()
        masks = outputs.pred_masks[0, 0, 0, 0, :3]
        self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.4515), atol=2e-4))
        self.assertTrue(torch.allclose(masks, torch.tensor([-4.1800, -3.4948, -3.4481]).to(torch_device), atol=2e-4))

    def test_inference_mask_generation_one_point_one_bb(self):
        model, processor = self._load_model_and_processor()

        raw_image = prepare_image()
        input_boxes = [[[650, 900, 1000, 1250]]]
        input_points = [[[820, 1080]]]

        inputs = processor(
            images=raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt"
        ).to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze()
        masks = outputs.pred_masks[0, 0, 0, 0, :3]
        self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9566), atol=2e-4))
        self.assertTrue(
            torch.allclose(masks, torch.tensor([-12.7729, -12.3665, -12.6061]).to(torch_device), atol=2e-4)
        )

    def test_inference_mask_generation_batched_points_batched_images(self):
        model, processor = self._load_model_and_processor()

        raw_image = prepare_image()
        input_points = [
            [[[820, 1080]], [[820, 1080]], [[820, 1080]], [[820, 1080]]],
            [[[510, 1080]], [[820, 1080]], [[820, 1080]], [[820, 1080]]],
        ]

        inputs = processor(images=[raw_image, raw_image], input_points=input_points, return_tensors="pt").to(
            torch_device
        )

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze().cpu()
        masks = outputs.pred_masks[0, 0, 0, 0, :3].cpu()

        EXPECTED_SCORES = torch.tensor(
            [
                [
                    [0.6765, 0.9379, 0.8803],
                    [0.6765, 0.9379, 0.8803],
                    [0.6765, 0.9379, 0.8803],
                    [0.6765, 0.9379, 0.8803],
                ],
                [
                    [0.3317, 0.7264, 0.7646],
                    [0.6765, 0.9379, 0.8803],
                    [0.6765, 0.9379, 0.8803],
                    [0.6765, 0.9379, 0.8803],
                ],
            ]
        )
        EXPECTED_MASKS = torch.tensor([-2.8550, -2.7988, -2.9625])
        self.assertTrue(torch.allclose(scores, EXPECTED_SCORES, atol=1e-3))
        self.assertTrue(torch.allclose(masks, EXPECTED_MASKS, atol=1e-3))

    def test_inference_mask_generation_one_point_one_bb_zero(self):
        model, processor = self._load_model_and_processor()

        raw_image = prepare_image()
        input_boxes = [[[620, 900, 1000, 1255]]]
        input_points = [[[820, 1080]]]
        labels = [[0]]

        inputs = processor(
            images=raw_image,
            input_boxes=input_boxes,
            input_points=input_points,
            input_labels=labels,
            return_tensors="pt",
        ).to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze()

        self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.7894), atol=1e-4))

    def test_inference_mask_generation_one_point(self):
        model, processor = self._load_model_and_processor()

        raw_image = prepare_image()

        input_points = [[[400, 650]]]
        input_labels = [[1]]

        inputs = processor(
            images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt"
        ).to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze()
        self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9675), atol=1e-4))

        # With no label
        input_points = [[[400, 650]]]

        inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt").to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze()
        self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9675), atol=1e-4))

    def test_inference_mask_generation_two_points(self):
        model, processor = self._load_model_and_processor()

        raw_image = prepare_image()

        input_points = [[[400, 650], [800, 650]]]
        input_labels = [[1, 1]]

        inputs = processor(
            images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt"
        ).to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze()
        self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9762), atol=1e-4))

        # no labels
        inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt").to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze()

        self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9762), atol=1e-4))

    def test_inference_mask_generation_two_points_batched(self):
        model, processor = self._load_model_and_processor()

        raw_image = prepare_image()

        input_points = [[[400, 650], [800, 650]], [[400, 650]]]
        input_labels = [[1, 1], [1]]

        inputs = processor(
            images=[raw_image, raw_image], input_points=input_points, input_labels=input_labels, return_tensors="pt"
        ).to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze()

        self.assertTrue(torch.allclose(scores[0][-1], torch.tensor(0.9762), atol=1e-4))
        self.assertTrue(torch.allclose(scores[1][-1], torch.tensor(0.9637), atol=1e-4))

    def test_inference_mask_generation_one_box(self):
        model, processor = self._load_model_and_processor()

        raw_image = prepare_image()

        input_boxes = [[[75, 275, 1725, 850]]]

        inputs = processor(images=raw_image, input_boxes=input_boxes, return_tensors="pt").to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze()
        self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.7937), atol=1e-4))

    def test_inference_mask_generation_batched_image_one_point(self):
        """The dog image's scores must not depend on whether it is batched with another image."""
        model, processor = self._load_model_and_processor()

        raw_image = prepare_image()
        raw_dog_image = prepare_dog_img()

        input_points = [[[820, 1080]], [[220, 470]]]

        inputs = processor(images=[raw_image, raw_dog_image], input_points=input_points, return_tensors="pt").to(
            torch_device
        )

        with torch.no_grad():
            outputs = model(**inputs)
        scores_batched = outputs.iou_scores.squeeze()

        input_points = [[[220, 470]]]

        inputs = processor(images=raw_dog_image, input_points=input_points, return_tensors="pt").to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores_single = outputs.iou_scores.squeeze()
        self.assertTrue(torch.allclose(scores_batched[1, :], scores_single, atol=1e-4))

    def test_inference_mask_generation_two_points_point_batch(self):
        model, processor = self._load_model_and_processor()

        raw_image = prepare_image()

        # fmt: off
        input_points = torch.Tensor([[[400, 650]], [[220, 470]]]).cpu()
        # fmt: on

        input_points = input_points.unsqueeze(0)

        inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)

        iou_scores = outputs.iou_scores.cpu()
        self.assertEqual(tuple(iou_scores.shape), (1, 2, 3))
        # `torch.testing.assert_allclose` is deprecated; `assert_close` is its replacement.
        torch.testing.assert_close(
            iou_scores, torch.tensor([[[0.9105, 0.9825, 0.9675], [0.7646, 0.7943, 0.7774]]]), atol=1e-4, rtol=1e-4
        )

    def test_inference_mask_generation_three_boxes_point_batch(self):
        model, processor = self._load_model_and_processor()

        raw_image = prepare_image()

        # fmt: off
        input_boxes = torch.Tensor([[[620, 900, 1000, 1255]], [[75, 275, 1725, 850]],  [[75, 275, 1725, 850]]]).cpu()
        EXPECTED_IOU = torch.tensor([[[0.9773, 0.9881, 0.9522],
         [0.5996, 0.7661, 0.7937],
         [0.5996, 0.7661, 0.7937]]])
        # fmt: on
        input_boxes = input_boxes.unsqueeze(0)

        inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt").to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)

        iou_scores = outputs.iou_scores.cpu()
        self.assertEqual(tuple(iou_scores.shape), (1, 3, 3))
        # `torch.testing.assert_allclose` is deprecated; `assert_close` is its replacement.
        torch.testing.assert_close(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4)

    def test_dummy_pipeline_generation(self):
        """Smoke test: the "mask-generation" pipeline runs end to end on the car image."""
        generator = pipeline("mask-generation", model="facebook/sam-vit-base", device=torch_device)
        raw_image = prepare_image()

        _ = generator(raw_image, points_per_batch=64)