"tests/test_modeling_tf_speech_to_text.py" did not exist on "1417978cd49181fd08837e7722c34dd5c8c113e3"
Unverified Commit 1c460a52 authored by Matt's avatar Matt Committed by GitHub
Browse files

TF port of the Segment Anything Model (SAM) (#22970)



* First commit

* Add auto-translation with GPT-4

* make fixup

* Add a functional layernorm for TF

* Add all the auxiliary imports etc.

* Add the extra processor and tests

* rebase to main

* Add all the needed fixes to the GPT code

* make fixup

* Make convolutions channels-last so they run on CPU

* make fixup

* Fix final issues

* Fix other models affected by test change

* Clarify comment on the sparse_prompt_embeddings check

* Refactor functional_layernorm, use shape_list in place of .shape in some places

* Remove deprecated torch-alike code

* Update tests/models/sam/test_modeling_tf_sam.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/sam/test_modeling_tf_sam.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Refactor processor with common methods and separated private methods

* make fixup

* Quietly delete the file that didn't do anything (sorry Sylvain)

* Refactor the processor tests into one file

* make fixup

* Clean up some unnecessary indirection

* Fix TF mask postprocessing

* Add more processor equivalence tests

* Refactor generate_crop_boxes to use framework-neutral np code

* Make the serving output correctly conditional

* Fix error message line length

* Use dict keys rather than indices internally in both TF and PT SAM call/forward

* Return dicts internally in the call/forward methods

* Revert changes to common tests and just override check_pt_tf_outputs

* Revert changes to other model tests

* Clarify comments for functional layernorm

* Add missing transpose from PT code

* Removed unused copied from in PT code

* Remove overrides for tests that don't exist in TF

* Fix transpose and update tests for PT and TF to check pred_masks

* Add training flag

* Update tests to use TF checkpoints

* Update index.mdx

* Add missing cross-test decorator

* Remove optional extra asterisks

* Revert return_dict changes in PT code

* Update src/transformers/models/sam/modeling_tf_sam.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Remove None return annotations on init methods

* Update tests/models/sam/test_processor_sam.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Fix input_boxes shapes

* make fixup

---------
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 8aa8513f
......@@ -399,7 +399,7 @@ Flax), PyTorch, and/or TensorFlow.
| RoCBert | ✅ | ❌ | ✅ | ❌ | ❌ |
| RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ |
| RWKV | ❌ | ❌ | ✅ | ❌ | ❌ |
| SAM | ❌ | ❌ | ✅ | | ❌ |
| SAM | ❌ | ❌ | ✅ | | ❌ |
| SegFormer | ❌ | ❌ | ✅ | ✅ | ❌ |
| SEW | ❌ | ❌ | ✅ | ❌ | ❌ |
| SEW-D | ❌ | ❌ | ✅ | ❌ | ❌ |
......
......@@ -99,3 +99,9 @@ Resources:
[[autodoc]] SamModel
- forward
## TFSamModel
[[autodoc]] TFSamModel
- call
\ No newline at end of file
......@@ -3406,6 +3406,13 @@ else:
"TFRoFormerPreTrainedModel",
]
)
_import_structure["models.sam"].extend(
[
"TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFSamModel",
"TFSamPreTrainedModel",
]
)
_import_structure["models.segformer"].extend(
[
"TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
......@@ -6657,6 +6664,11 @@ if TYPE_CHECKING:
TFRoFormerModel,
TFRoFormerPreTrainedModel,
)
from .models.sam import (
TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSamModel,
TFSamPreTrainedModel,
)
from .models.segformer import (
TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSegformerDecodeHead,
......
......@@ -76,6 +76,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
("roberta", "TFRobertaModel"),
("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
("roformer", "TFRoFormerModel"),
("sam", "TFSamModel"),
("segformer", "TFSegformerModel"),
("speech_to_text", "TFSpeech2TextModel"),
("swin", "TFSwinModel"),
......@@ -426,6 +427,11 @@ TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
("mobilebert", "TFMobileBertForNextSentencePrediction"),
]
)
TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
[
("sam", "TFSamModel"),
]
)
TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES)
TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
......@@ -476,6 +482,14 @@ TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES
)
class TFAutoModelForMaskGeneration(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_MASK_GENERATION_MAPPING
class TFAutoModel(_BaseAutoModelClass):
_model_mapping = TF_MODEL_MAPPING
......
......@@ -13,7 +13,13 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_tf_available,
is_torch_available,
is_vision_available,
)
_import_structure = {
......@@ -39,6 +45,17 @@ else:
"SamModel",
"SamPreTrainedModel",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_sam"] = [
"TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFSamModel",
"TFSamPreTrainedModel",
]
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
......@@ -66,6 +83,14 @@ if TYPE_CHECKING:
else:
from .modeling_sam import SAM_PRETRAINED_MODEL_ARCHIVE_LIST, SamModel, SamPreTrainedModel
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_sam import TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST, TFSamModel, TFSamPreTrainedModel
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
......
......@@ -111,7 +111,6 @@ class SamImageSegmentationOutput(ModelOutput):
mask_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
# Copied from src.models.modeling_vit_mae.ViTMAEPatchEmbeddings with ViTMAEPatchEmbeddings->SamVisionEmbeddings,x->embeddings
class SamPatchEmbeddings(nn.Module):
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
......@@ -198,7 +197,7 @@ class SamAttention(nn.Module):
values.
"""
def __init__(self, config, downsample_rate=None) -> None:
def __init__(self, config, downsample_rate=None):
super().__init__()
self.hidden_size = config.hidden_size
......@@ -252,7 +251,7 @@ class SamAttention(nn.Module):
class SamTwoWayAttentionBlock(nn.Module):
def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False) -> None:
def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
"""
A transformer block with four layers:
(1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
......@@ -476,7 +475,7 @@ class SamMaskDecoder(nn.Module):
the embeddings of the mask inputs
multimask_output (bool):
Whether to return multiple masks or a single mask.
output_attentions (bool, **optional**):
output_attentions (bool, *optional*):
Whether or not to return the attentions tensors of all attention layers.
"""
batch_size, num_channels, height, width = image_embeddings.shape
......@@ -668,11 +667,11 @@ class SamPromptEncoder(nn.Module):
Embeds different types of prompts, returning both sparse and dense embeddings.
Args:
points (`torch.Tensor`, **optionnal**):
points (`torch.Tensor`, *optional*):
point coordinates and labels to embed.
boxes (`torch.Tensor`, **optionnal**):
boxes (`torch.Tensor`, *optional*):
boxes to embed
masks (`torch.Tensor`, **optionnal**):
masks (`torch.Tensor`, *optional*):
masks to embed
"""
sparse_embeddings = None
......@@ -707,7 +706,7 @@ class SamPromptEncoder(nn.Module):
class SamVisionAttention(nn.Module):
"""Multi-head Attention block with relative position embeddings."""
def __init__(self, config, window_size) -> None:
def __init__(self, config, window_size):
super().__init__()
input_size = (
(config.image_size // config.patch_size, config.image_size // config.patch_size)
......@@ -845,7 +844,7 @@ class SamVisionAttention(nn.Module):
class SamVisionLayer(nn.Module):
def __init__(self, config, window_size) -> None:
def __init__(self, config, window_size):
super().__init__()
self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attn = SamVisionAttention(config, window_size)
......@@ -1166,7 +1165,7 @@ SAM_INPUTS_DOCSTRING = r"""
class SamModel(SamPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"]
def __init__(self, config) -> None:
def __init__(self, config):
super().__init__(config)
self.shared_image_embedding = SamPositionalEmbedding(config.vision_config)
......@@ -1334,7 +1333,6 @@ class SamModel(SamPreTrainedModel):
image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
vision_attentions = None
mask_decoder_attentions = None
vision_hidden_states = None
if pixel_values is not None:
......@@ -1359,7 +1357,8 @@ class SamModel(SamPreTrainedModel):
"The batch size of the image embeddings and the input points must be the same. ",
"Got {} and {} respectively.".format(image_embeddings.shape[0], input_points.shape[0]),
" if you want to pass multiple points for the same image, make sure that you passed ",
" input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
" input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ",
" input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
)
sparse_embeddings, dense_embeddings = self.prompt_encoder(
......
This diff is collapsed.
......@@ -22,12 +22,15 @@ import numpy as np
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding
from ...utils import TensorType, is_torch_available
from ...utils import TensorType, is_tf_available, is_torch_available
if is_torch_available():
import torch
if is_tf_available():
import tensorflow as tf
class SamProcessor(ProcessorMixin):
r"""
......@@ -72,7 +75,7 @@ class SamProcessor(ProcessorMixin):
# pop arguments that are not used in the foward but used nevertheless
original_sizes = encoding_image_processor["original_sizes"]
if isinstance(original_sizes, torch.Tensor):
if hasattr(original_sizes, "numpy"): # Checks if Torch or TF tensor
original_sizes = original_sizes.numpy()
input_points, input_labels, input_boxes = self._check_and_preprocess_points(
......@@ -139,18 +142,30 @@ class SamProcessor(ProcessorMixin):
input_boxes = torch.from_numpy(input_boxes)
# boxes batch size of 1 by default
input_boxes = input_boxes.unsqueeze(1) if len(input_boxes.shape) != 3 else input_boxes
elif return_tensors == "tf":
input_boxes = tf.convert_to_tensor(input_boxes)
# boxes batch size of 1 by default
input_boxes = tf.expand_dims(input_boxes, 1) if len(input_boxes.shape) != 3 else input_boxes
encoding_image_processor.update({"input_boxes": input_boxes})
if input_points is not None:
if return_tensors == "pt":
input_points = torch.from_numpy(input_points)
# point batch size of 1 by default
input_points = input_points.unsqueeze(1) if len(input_points.shape) != 4 else input_points
elif return_tensors == "tf":
input_points = tf.convert_to_tensor(input_points)
# point batch size of 1 by default
input_points = tf.expand_dims(input_points, 1) if len(input_points.shape) != 4 else input_points
encoding_image_processor.update({"input_points": input_points})
if input_labels is not None:
if return_tensors == "pt":
input_labels = torch.from_numpy(input_labels)
# point batch size of 1 by default
input_labels = input_labels.unsqueeze(1) if len(input_labels.shape) != 3 else input_labels
elif return_tensors == "tf":
input_labels = tf.convert_to_tensor(input_labels)
# point batch size of 1 by default
input_labels = tf.expand_dims(input_labels, 1) if len(input_labels.shape) != 3 else input_labels
encoding_image_processor.update({"input_labels": input_labels})
return encoding_image_processor
......@@ -204,7 +219,7 @@ class SamProcessor(ProcessorMixin):
it is converted to a `numpy.ndarray` and then to a `list`.
"""
if input_points is not None:
if isinstance(input_points, torch.Tensor):
if hasattr(input_points, "numpy"): # Checks for TF or Torch tensor
input_points = input_points.numpy().tolist()
if not isinstance(input_points, list) or not isinstance(input_points[0], list):
......@@ -214,7 +229,7 @@ class SamProcessor(ProcessorMixin):
input_points = None
if input_labels is not None:
if isinstance(input_labels, torch.Tensor):
if hasattr(input_labels, "numpy"):
input_labels = input_labels.numpy().tolist()
if not isinstance(input_labels, list) or not isinstance(input_labels[0], list):
......@@ -224,7 +239,7 @@ class SamProcessor(ProcessorMixin):
input_labels = None
if input_boxes is not None:
if isinstance(input_boxes, torch.Tensor):
if hasattr(input_boxes, "numpy"):
input_boxes = input_boxes.numpy().tolist()
if (
......
......@@ -70,6 +70,56 @@ def stable_softmax(logits: tf.Tensor, axis: Optional[int] = None, name: Optional
return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)
def functional_layernorm(inputs, weight, bias, epsilon=1e-5, axis=-1):
# This is a very simplified functional layernorm, designed to duplicate
# the functionality of PyTorch nn.functional.layer_norm when this is needed to port
# models in Transformers.
if weight.shape.rank != 1 or bias.shape.rank != 1 or not isinstance(axis, int):
raise NotImplementedError("Only 1D weight and bias tensors are supported for now, with only a single axis.")
# Get mean and variance on the axis to be normalized
mean, variance = tf.nn.moments(inputs, axes=[axis], keepdims=True)
if axis != -1:
# Reshape scale and weight to have the same rank as inputs, but with 1 dimensions
# on every dimension except axis
shape = [1] * inputs.shape.rank
shape[axis] = shape_list(inputs)[axis]
weight = tf.reshape(weight, shape)
bias = tf.reshape(bias, shape)
# Compute layer normalization using the batch_normalization
# function.
outputs = tf.nn.batch_normalization(
inputs,
mean,
variance,
offset=bias,
scale=weight,
variance_epsilon=epsilon,
)
return outputs
def flatten(input, start_dim=0, end_dim=-1):
# Replicates the behavior of torch.flatten in TF
# If end_dim or start_dim is negative, count them from the end
if end_dim < 0:
end_dim += input.shape.rank
if start_dim < 0:
start_dim += input.shape.rank
if start_dim == end_dim:
return input
in_shape = tf.shape(input)
flattened_dim = tf.math.reduce_prod(in_shape[start_dim : end_dim + 1])
out_shape = tf.concat([in_shape[:start_dim], [flattened_dim], in_shape[end_dim + 1 :]], axis=0)
return tf.reshape(input, out_shape)
def invert_attention_mask(encoder_attention_mask: tf.Tensor) -> tf.Tensor:
"""
Invert an attention mask (e.g., switches 0. and 1.).
......
......@@ -2317,6 +2317,23 @@ class TFRoFormerPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["tf"])
TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TFSamModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFSamPreTrainedModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
......@@ -436,6 +436,9 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def test_hidden_states_output(self):
pass
def test_pt_tf_model_equivalence(self, allow_missing_keys=True, tol=5e-4):
super().test_pt_tf_model_equivalence(allow_missing_keys=True, tol=tol)
@slow
def test_model_from_pretrained(self):
for model_name in SAM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
......@@ -470,8 +473,10 @@ class SamModelIntegrationTest(unittest.TestCase):
with torch.no_grad():
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()
masks = outputs.pred_masks[0, 0, 0, 0, :3]
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=1e-4))
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=2e-4))
self.assertTrue(torch.allclose(masks, torch.tensor([-6.6381, -6.0734, -7.5308]).to(torch_device), atol=2e-4))
def test_inference_mask_generation_one_point_one_bb(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge")
......@@ -491,8 +496,12 @@ class SamModelIntegrationTest(unittest.TestCase):
with torch.no_grad():
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()
masks = outputs.pred_masks[0, 0, 0, 0, :3]
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9935), atol=1e-4))
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9935), atol=2e-4))
self.assertTrue(
torch.allclose(masks, torch.tensor([-21.5465, -23.1122, -22.3331]).to(torch_device), atol=2e-4)
)
def test_inference_mask_generation_batched_points_batched_images(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge")
......@@ -514,6 +523,7 @@ class SamModelIntegrationTest(unittest.TestCase):
with torch.no_grad():
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze().cpu()
masks = outputs.pred_masks[0, 0, 0, 0, :3].cpu()
EXPECTED_SCORES = torch.tensor(
[
......@@ -531,7 +541,9 @@ class SamModelIntegrationTest(unittest.TestCase):
],
]
)
EXPECTED_MASKS = torch.tensor([-26.5424, -34.0901, -30.6406])
self.assertTrue(torch.allclose(scores, EXPECTED_SCORES, atol=1e-3))
self.assertTrue(torch.allclose(masks, EXPECTED_MASKS, atol=1e-3))
def test_inference_mask_generation_one_point_one_bb_zero(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge")
......
This diff is collapsed.
......@@ -17,8 +17,14 @@ import unittest
import numpy as np
from transformers.testing_utils import require_torch, require_torchvision, require_vision
from transformers.utils import is_torch_available, is_vision_available
from transformers.testing_utils import (
is_pt_tf_cross_test,
require_tf,
require_torch,
require_torchvision,
require_vision,
)
from transformers.utils import is_tf_available, is_torch_available, is_vision_available
if is_vision_available():
......@@ -29,6 +35,9 @@ if is_vision_available():
if is_torch_available():
import torch
if is_tf_available():
import tensorflow as tf
@require_vision
@require_torchvision
......@@ -110,3 +119,158 @@ class SamProcessorTest(unittest.TestCase):
dummy_masks = [[1, 0], [0, 1]]
with self.assertRaises(ValueError):
masks = processor.post_process_masks(dummy_masks, np.array(original_sizes), np.array(reshaped_input_size))
@require_vision
@require_tf
class TFSamProcessorTest(unittest.TestCase):
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
image_processor = SamImageProcessor()
processor = SamProcessor(image_processor)
processor.save_pretrained(self.tmpdirname)
def get_image_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def prepare_image_inputs(self):
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
or a list of PyTorch tensors if one specifies torchify=True.
"""
image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
return image_inputs
def test_save_load_pretrained_additional_features(self):
processor = SamProcessor(image_processor=self.get_image_processor())
processor.save_pretrained(self.tmpdirname)
image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0)
processor = SamProcessor.from_pretrained(self.tmpdirname, do_normalize=False, padding_value=1.0)
self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string())
self.assertIsInstance(processor.image_processor, SamImageProcessor)
def test_image_processor(self):
image_processor = self.get_image_processor()
processor = SamProcessor(image_processor=image_processor)
image_input = self.prepare_image_inputs()
input_feat_extract = image_processor(image_input, return_tensors="np")
input_processor = processor(images=image_input, return_tensors="np")
input_feat_extract.pop("original_sizes") # pop original_sizes as it is popped in the processor
input_feat_extract.pop("reshaped_input_sizes") # pop reshaped_input_sizes as it is popped in the processor
for key in input_feat_extract.keys():
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
@require_tf
def test_post_process_masks(self):
image_processor = self.get_image_processor()
processor = SamProcessor(image_processor=image_processor)
dummy_masks = [tf.ones((1, 3, 5, 5))]
original_sizes = [[1764, 2646]]
reshaped_input_size = [[683, 1024]]
masks = processor.post_process_masks(dummy_masks, original_sizes, reshaped_input_size, return_tensors="tf")
self.assertEqual(masks[0].shape, (1, 3, 1764, 2646))
masks = processor.post_process_masks(
dummy_masks,
tf.convert_to_tensor(original_sizes),
tf.convert_to_tensor(reshaped_input_size),
return_tensors="tf",
)
self.assertEqual(masks[0].shape, (1, 3, 1764, 2646))
# should also work with np
dummy_masks = [np.ones((1, 3, 5, 5))]
masks = processor.post_process_masks(
dummy_masks, np.array(original_sizes), np.array(reshaped_input_size), return_tensors="tf"
)
self.assertEqual(masks[0].shape, (1, 3, 1764, 2646))
dummy_masks = [[1, 0], [0, 1]]
with self.assertRaises(tf.errors.InvalidArgumentError):
masks = processor.post_process_masks(
dummy_masks, np.array(original_sizes), np.array(reshaped_input_size), return_tensors="tf"
)
@require_vision
@require_torchvision
class SamProcessorEquivalenceTest(unittest.TestCase):
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
image_processor = SamImageProcessor()
processor = SamProcessor(image_processor)
processor.save_pretrained(self.tmpdirname)
def get_image_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def prepare_image_inputs(self):
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
or a list of PyTorch tensors if one specifies torchify=True.
"""
image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
return image_inputs
@is_pt_tf_cross_test
def test_post_process_masks_equivalence(self):
image_processor = self.get_image_processor()
processor = SamProcessor(image_processor=image_processor)
dummy_masks = np.random.randint(0, 2, size=(1, 3, 5, 5)).astype(np.float32)
tf_dummy_masks = [tf.convert_to_tensor(dummy_masks)]
pt_dummy_masks = [torch.tensor(dummy_masks)]
original_sizes = [[1764, 2646]]
reshaped_input_size = [[683, 1024]]
tf_masks = processor.post_process_masks(
tf_dummy_masks, original_sizes, reshaped_input_size, return_tensors="tf"
)
pt_masks = processor.post_process_masks(
pt_dummy_masks, original_sizes, reshaped_input_size, return_tensors="pt"
)
self.assertTrue(np.all(tf_masks[0].numpy() == pt_masks[0].numpy()))
@is_pt_tf_cross_test
def test_image_processor_equivalence(self):
image_processor = self.get_image_processor()
processor = SamProcessor(image_processor=image_processor)
image_input = self.prepare_image_inputs()
pt_input_feat_extract = image_processor(image_input, return_tensors="pt")["pixel_values"].numpy()
pt_input_processor = processor(images=image_input, return_tensors="pt")["pixel_values"].numpy()
tf_input_feat_extract = image_processor(image_input, return_tensors="tf")["pixel_values"].numpy()
tf_input_processor = processor(images=image_input, return_tensors="tf")["pixel_values"].numpy()
self.assertTrue(np.allclose(pt_input_feat_extract, pt_input_processor))
self.assertTrue(np.allclose(pt_input_feat_extract, tf_input_feat_extract))
self.assertTrue(np.allclose(pt_input_feat_extract, tf_input_processor))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment