Unverified commit 1c460a52, authored by Matt and committed by GitHub

TF port of the Segment Anything Model (SAM) (#22970)



* First commit

* Add auto-translation with GPT-4

* make fixup

* Add a functional layernorm for TF

* Add all the auxiliary imports etc.

* Add the extra processor and tests

* rebase to main

* Add all the needed fixes to the GPT code

* make fixup

* Make convolutions channels-last so they run on CPU

* make fixup

* Fix final issues

* Fix other models affected by test change

* Clarify comment on the sparse_prompt_embeddings check

* Refactor functional_layernorm, use shape_list in place of .shape in some places

* Remove deprecated torch-alike code

* Update tests/models/sam/test_modeling_tf_sam.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/sam/test_modeling_tf_sam.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Refactor processor with common methods and separated private methods

* make fixup

* Quietly delete the file that didn't do anything (sorry Sylvain)

* Refactor the processor tests into one file

* make fixup

* Clean up some unnecessary indirection

* Fix TF mask postprocessing

* Add more processor equivalence tests

* Refactor generate_crop_boxes to use framework-neutral np code

* Make the serving output correctly conditional

* Fix error message line length

* Use dict keys rather than indices internally in both TF and PT SAM call/forward

* Return dicts internally in the call/forward methods

* Revert changes to common tests and just override check_pt_tf_outputs

* Revert changes to other model tests

* Clarify comments for functional layernorm

* Add missing transpose from PT code

* Removed unused copied from in PT code

* Remove overrides for tests that don't exist in TF

* Fix transpose and update tests for PT and TF to check pred_masks

* Add training flag

* Update tests to use TF checkpoints

* Update index.mdx

* Add missing cross-test decorator

* Remove optional extra asterisks

* Revert return_dict changes in PT code

* Update src/transformers/models/sam/modeling_tf_sam.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Remove None return annotations on init methods

* Update tests/models/sam/test_processor_sam.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Fix input_boxes shapes

* make fixup

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 8aa8513f
@@ -399,7 +399,7 @@ Flax), PyTorch, and/or TensorFlow.
 | RoCBert | ✅ | ❌ | ✅ | ❌ | ❌ |
 | RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ |
 | RWKV | ❌ | ❌ | ✅ | ❌ | ❌ |
-| SAM | ❌ | ❌ | ✅ | ❌ | ❌ |
+| SAM | ❌ | ❌ | ✅ | ✅ | ❌ |
 | SegFormer | ❌ | ❌ | ✅ | ✅ | ❌ |
 | SEW | ❌ | ❌ | ✅ | ❌ | ❌ |
 | SEW-D | ❌ | ❌ | ✅ | ❌ | ❌ |
...
@@ -99,3 +99,9 @@ Resources:
 [[autodoc]] SamModel
     - forward
+
+## TFSamModel
+
+[[autodoc]] TFSamModel
+    - call
\ No newline at end of file
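For reference, a minimal usage sketch of the TF class documented above. This is not part of the diff; the checkpoint name, image URL, and point coordinates are illustrative, and it assumes TF weights are available for the checkpoint (otherwise `from_pt=True` can be passed).

```python
import requests
from PIL import Image

from transformers import SamProcessor, TFSamModel

# Assumption: this checkpoint has TF-compatible weights.
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model = TFSamModel.from_pretrained("facebook/sam-vit-base")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # any RGB image works
raw_image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
input_points = [[[250, 250]]]  # one (x, y) point prompt for one image

inputs = processor(raw_image, input_points=input_points, return_tensors="tf")
outputs = model(pixel_values=inputs["pixel_values"], input_points=inputs["input_points"])

# Rescale the low-resolution predicted masks back to the original image size.
masks = processor.post_process_masks(
    outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"], return_tensors="tf"
)
scores = outputs.iou_scores
```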
@@ -3406,6 +3406,13 @@ else:
             "TFRoFormerPreTrainedModel",
         ]
     )
+    _import_structure["models.sam"].extend(
+        [
+            "TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "TFSamModel",
+            "TFSamPreTrainedModel",
+        ]
+    )
     _import_structure["models.segformer"].extend(
         [
             "TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -6657,6 +6664,11 @@ if TYPE_CHECKING:
             TFRoFormerModel,
             TFRoFormerPreTrainedModel,
         )
+        from .models.sam import (
+            TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST,
+            TFSamModel,
+            TFSamPreTrainedModel,
+        )
         from .models.segformer import (
             TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
             TFSegformerDecodeHead,
...
@@ -76,6 +76,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
         ("roberta", "TFRobertaModel"),
         ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
         ("roformer", "TFRoFormerModel"),
+        ("sam", "TFSamModel"),
         ("segformer", "TFSegformerModel"),
         ("speech_to_text", "TFSpeech2TextModel"),
         ("swin", "TFSwinModel"),
@@ -426,6 +427,11 @@ TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
         ("mobilebert", "TFMobileBertForNextSentencePrediction"),
     ]
 )
+TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
+    [
+        ("sam", "TFSamModel"),
+    ]
+)
 TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES)
 TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
@@ -476,6 +482,14 @@ TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
     CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
 )
+TF_MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(
+    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES
+)
+
+
+class TFAutoModelForMaskGeneration(_BaseAutoModelClass):
+    _model_mapping = TF_MODEL_FOR_MASK_GENERATION_MAPPING
+
 class TFAutoModel(_BaseAutoModelClass):
     _model_mapping = TF_MODEL_MAPPING
...
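With the new `("sam", "TFSamModel")` entry in `TF_MODEL_MAPPING_NAMES`, the generic TF auto class resolves a SAM config to the new model. A minimal sketch (not part of the diff; instantiating from a bare config builds a randomly initialized model):

```python
from transformers import SamConfig, TFAutoModel

# The "sam" mapping entry means the generic TF auto class now resolves
# a SAM configuration to TFSamModel (random weights in this sketch).
config = SamConfig()
model = TFAutoModel.from_config(config)
print(type(model).__name__)  # TFSamModel
```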
@@ -13,7 +13,13 @@
 # limitations under the License.
 from typing import TYPE_CHECKING

-from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
+from ...utils import (
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    is_tf_available,
+    is_torch_available,
+    is_vision_available,
+)

 _import_structure = {
@@ -39,6 +45,17 @@ else:
         "SamModel",
         "SamPreTrainedModel",
     ]
+
+try:
+    if not is_tf_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["modeling_tf_sam"] = [
+        "TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST",
+        "TFSamModel",
+        "TFSamPreTrainedModel",
+    ]
 try:
     if not is_vision_available():
         raise OptionalDependencyNotAvailable()
@@ -66,6 +83,14 @@ if TYPE_CHECKING:
     else:
         from .modeling_sam import SAM_PRETRAINED_MODEL_ARCHIVE_LIST, SamModel, SamPreTrainedModel

+    try:
+        if not is_tf_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .modeling_tf_sam import TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST, TFSamModel, TFSamPreTrainedModel
+
     try:
         if not is_vision_available():
             raise OptionalDependencyNotAvailable()
...
@@ -111,7 +111,6 @@ class SamImageSegmentationOutput(ModelOutput):
     mask_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None

-# Copied from src.models.modeling_vit_mae.ViTMAEPatchEmbeddings with ViTMAEPatchEmbeddings->SamVisionEmbeddings,x->embeddings
 class SamPatchEmbeddings(nn.Module):
     """
     This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
@@ -198,7 +197,7 @@ class SamAttention(nn.Module):
     values.
     """

-    def __init__(self, config, downsample_rate=None) -> None:
+    def __init__(self, config, downsample_rate=None):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -252,7 +251,7 @@ class SamAttention(nn.Module):
 class SamTwoWayAttentionBlock(nn.Module):
-    def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False) -> None:
+    def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
         """
         A transformer block with four layers:
             (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
@@ -476,7 +475,7 @@ class SamMaskDecoder(nn.Module):
                 the embeddings of the mask inputs
             multimask_output (bool):
                 Whether to return multiple masks or a single mask.
-            output_attentions (bool, **optional**):
+            output_attentions (bool, *optional*):
                 Whether or not to return the attentions tensors of all attention layers.
         """
         batch_size, num_channels, height, width = image_embeddings.shape
@@ -668,11 +667,11 @@ class SamPromptEncoder(nn.Module):
         Embeds different types of prompts, returning both sparse and dense embeddings.

         Args:
-            points (`torch.Tensor`, **optionnal**):
+            points (`torch.Tensor`, *optional*):
                 point coordinates and labels to embed.
-            boxes (`torch.Tensor`, **optionnal**):
+            boxes (`torch.Tensor`, *optional*):
                 boxes to embed
-            masks (`torch.Tensor`, **optionnal**):
+            masks (`torch.Tensor`, *optional*):
                 masks to embed
         """
         sparse_embeddings = None
@@ -707,7 +706,7 @@ class SamPromptEncoder(nn.Module):
 class SamVisionAttention(nn.Module):
     """Multi-head Attention block with relative position embeddings."""

-    def __init__(self, config, window_size) -> None:
+    def __init__(self, config, window_size):
         super().__init__()
         input_size = (
             (config.image_size // config.patch_size, config.image_size // config.patch_size)
@@ -845,7 +844,7 @@ class SamVisionAttention(nn.Module):
 class SamVisionLayer(nn.Module):
-    def __init__(self, config, window_size) -> None:
+    def __init__(self, config, window_size):
         super().__init__()
         self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.attn = SamVisionAttention(config, window_size)
@@ -1166,7 +1165,7 @@ SAM_INPUTS_DOCSTRING = r"""
 class SamModel(SamPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"]

-    def __init__(self, config) -> None:
+    def __init__(self, config):
         super().__init__(config)
         self.shared_image_embedding = SamPositionalEmbedding(config.vision_config)
@@ -1334,7 +1333,6 @@ class SamModel(SamPreTrainedModel):
         image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)

         vision_attentions = None
-        mask_decoder_attentions = None
         vision_hidden_states = None

         if pixel_values is not None:
@@ -1359,7 +1357,8 @@ class SamModel(SamPreTrainedModel):
                 "The batch size of the image embeddings and the input points must be the same. ",
                 "Got {} and {} respectively.".format(image_embeddings.shape[0], input_points.shape[0]),
                 " if you want to pass multiple points for the same image, make sure that you passed ",
-                " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
+                " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ",
+                " input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
             )

         sparse_embeddings, dense_embeddings = self.prompt_encoder(
...
This diff is collapsed.
@@ -22,12 +22,15 @@ import numpy as np
 from ...processing_utils import ProcessorMixin
 from ...tokenization_utils_base import BatchEncoding
-from ...utils import TensorType, is_torch_available
+from ...utils import TensorType, is_tf_available, is_torch_available

 if is_torch_available():
     import torch

+if is_tf_available():
+    import tensorflow as tf
+

 class SamProcessor(ProcessorMixin):
     r"""
@@ -72,7 +75,7 @@ class SamProcessor(ProcessorMixin):
         # pop arguments that are not used in the foward but used nevertheless
         original_sizes = encoding_image_processor["original_sizes"]

-        if isinstance(original_sizes, torch.Tensor):
+        if hasattr(original_sizes, "numpy"):  # Checks if Torch or TF tensor
             original_sizes = original_sizes.numpy()

         input_points, input_labels, input_boxes = self._check_and_preprocess_points(
@@ -139,18 +142,30 @@ class SamProcessor(ProcessorMixin):
                 input_boxes = torch.from_numpy(input_boxes)
                 # boxes batch size of 1 by default
                 input_boxes = input_boxes.unsqueeze(1) if len(input_boxes.shape) != 3 else input_boxes
+            elif return_tensors == "tf":
+                input_boxes = tf.convert_to_tensor(input_boxes)
+                # boxes batch size of 1 by default
+                input_boxes = tf.expand_dims(input_boxes, 1) if len(input_boxes.shape) != 3 else input_boxes
             encoding_image_processor.update({"input_boxes": input_boxes})
         if input_points is not None:
             if return_tensors == "pt":
                 input_points = torch.from_numpy(input_points)
                 # point batch size of 1 by default
                 input_points = input_points.unsqueeze(1) if len(input_points.shape) != 4 else input_points
+            elif return_tensors == "tf":
+                input_points = tf.convert_to_tensor(input_points)
+                # point batch size of 1 by default
+                input_points = tf.expand_dims(input_points, 1) if len(input_points.shape) != 4 else input_points
             encoding_image_processor.update({"input_points": input_points})
         if input_labels is not None:
             if return_tensors == "pt":
                 input_labels = torch.from_numpy(input_labels)
                 # point batch size of 1 by default
                 input_labels = input_labels.unsqueeze(1) if len(input_labels.shape) != 3 else input_labels
+            elif return_tensors == "tf":
+                input_labels = tf.convert_to_tensor(input_labels)
+                # point batch size of 1 by default
+                input_labels = tf.expand_dims(input_labels, 1) if len(input_labels.shape) != 3 else input_labels
             encoding_image_processor.update({"input_labels": input_labels})

         return encoding_image_processor
@@ -204,7 +219,7 @@ class SamProcessor(ProcessorMixin):
         it is converted to a `numpy.ndarray` and then to a `list`.
         """
         if input_points is not None:
-            if isinstance(input_points, torch.Tensor):
+            if hasattr(input_points, "numpy"):  # Checks for TF or Torch tensor
                 input_points = input_points.numpy().tolist()

             if not isinstance(input_points, list) or not isinstance(input_points[0], list):
@@ -214,7 +229,7 @@ class SamProcessor(ProcessorMixin):
             input_points = None

         if input_labels is not None:
-            if isinstance(input_labels, torch.Tensor):
+            if hasattr(input_labels, "numpy"):
                 input_labels = input_labels.numpy().tolist()

             if not isinstance(input_labels, list) or not isinstance(input_labels[0], list):
@@ -224,7 +239,7 @@ class SamProcessor(ProcessorMixin):
             input_labels = None

         if input_boxes is not None:
-            if isinstance(input_boxes, torch.Tensor):
+            if hasattr(input_boxes, "numpy"):
                 input_boxes = input_boxes.numpy().tolist()

             if (
...
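A quick sketch of what the new `return_tensors="tf"` branch produces. This is not part of the diff; the checkpoint name and the synthetic image are illustrative, and the shapes below assume the default SAM image size of 1024.

```python
import numpy as np
from PIL import Image

from transformers import SamProcessor

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
image = Image.fromarray(np.zeros((30, 400, 3), dtype=np.uint8))

# One image, one (x, y) point prompt; the TF branch converts the points to a
# tf.Tensor and inserts the point-batch dimension of 1, mirroring the PyTorch path.
inputs = processor(image, input_points=[[[200, 15]]], return_tensors="tf")
print(inputs["pixel_values"].shape)  # (1, 3, 1024, 1024)
print(inputs["input_points"].shape)  # (1, 1, 1, 2)
```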
@@ -70,6 +70,56 @@ def stable_softmax(logits: tf.Tensor, axis: Optional[int] = None, name: Optional[str] = None) -> tf.Tensor:
     return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)

+def functional_layernorm(inputs, weight, bias, epsilon=1e-5, axis=-1):
+    # This is a very simplified functional layernorm, designed to duplicate
+    # the functionality of PyTorch nn.functional.layer_norm when this is needed to port
+    # models in Transformers.
+
+    if weight.shape.rank != 1 or bias.shape.rank != 1 or not isinstance(axis, int):
+        raise NotImplementedError("Only 1D weight and bias tensors are supported for now, with only a single axis.")
+
+    # Get mean and variance on the axis to be normalized
+    mean, variance = tf.nn.moments(inputs, axes=[axis], keepdims=True)
+
+    if axis != -1:
+        # Reshape scale and weight to have the same rank as inputs, but with 1 dimensions
+        # on every dimension except axis
+        shape = [1] * inputs.shape.rank
+        shape[axis] = shape_list(inputs)[axis]
+        weight = tf.reshape(weight, shape)
+        bias = tf.reshape(bias, shape)
+
+    # Compute layer normalization using the batch_normalization
+    # function.
+    outputs = tf.nn.batch_normalization(
+        inputs,
+        mean,
+        variance,
+        offset=bias,
+        scale=weight,
+        variance_epsilon=epsilon,
+    )
+    return outputs
+
+
+def flatten(input, start_dim=0, end_dim=-1):
+    # Replicates the behavior of torch.flatten in TF
+
+    # If end_dim or start_dim is negative, count them from the end
+    if end_dim < 0:
+        end_dim += input.shape.rank
+    if start_dim < 0:
+        start_dim += input.shape.rank
+
+    if start_dim == end_dim:
+        return input
+
+    in_shape = tf.shape(input)
+    flattened_dim = tf.math.reduce_prod(in_shape[start_dim : end_dim + 1])
+    out_shape = tf.concat([in_shape[:start_dim], [flattened_dim], in_shape[end_dim + 1 :]], axis=0)
+    return tf.reshape(input, out_shape)
+
+
 def invert_attention_mask(encoder_attention_mask: tf.Tensor) -> tf.Tensor:
     """
     Invert an attention mask (e.g., switches 0. and 1.).
...
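A small sanity-check sketch of the two new helpers (not part of the diff; it assumes they are importable from `transformers.tf_utils`, which is the module being modified here):

```python
import numpy as np
import tensorflow as tf

from transformers.tf_utils import flatten, functional_layernorm

x = tf.constant(np.random.randn(2, 4, 8).astype(np.float32))

# With unit weight and zero bias, functional_layernorm should match Keras
# LayerNormalization (same axis and epsilon) up to numerical noise.
out = functional_layernorm(x, weight=tf.ones((8,)), bias=tf.zeros((8,)), axis=-1)
ref = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-5)(x)
print(np.allclose(out.numpy(), ref.numpy(), atol=1e-5))  # True

# flatten mirrors torch.flatten: collapse dims 1..-1 into a single dimension.
print(flatten(x, start_dim=1).shape)  # (2, 32)
```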
@@ -2317,6 +2317,23 @@ class TFRoFormerPreTrainedModel(metaclass=DummyObject):
         requires_backends(self, ["tf"])

+TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class TFSamModel(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
+class TFSamPreTrainedModel(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
 TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
...
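For context, these dummies keep the top-level import working without TensorFlow and defer the failure to first use. A hedged sketch of the behaviour (hypothetical session in an environment where TensorFlow is not installed):

```python
# Hypothetical: TensorFlow is NOT installed. The import still succeeds because
# the dummy class is returned, but any use fails with an informative error.
from transformers import TFSamModel

try:
    TFSamModel()  # triggers requires_backends(self, ["tf"])
except ImportError as err:
    print(err)  # explains that TensorFlow must be installed to use TFSamModel
```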
@@ -436,6 +436,9 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     def test_hidden_states_output(self):
         pass

+    def test_pt_tf_model_equivalence(self, allow_missing_keys=True, tol=5e-4):
+        super().test_pt_tf_model_equivalence(allow_missing_keys=True, tol=tol)
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in SAM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
@@ -470,8 +473,10 @@ class SamModelIntegrationTest(unittest.TestCase):
         with torch.no_grad():
             outputs = model(**inputs)
         scores = outputs.iou_scores.squeeze()
+        masks = outputs.pred_masks[0, 0, 0, 0, :3]
-        self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=1e-4))
+        self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=2e-4))
+        self.assertTrue(torch.allclose(masks, torch.tensor([-6.6381, -6.0734, -7.5308]).to(torch_device), atol=2e-4))

     def test_inference_mask_generation_one_point_one_bb(self):
         model = SamModel.from_pretrained("facebook/sam-vit-huge")
@@ -491,8 +496,12 @@
         with torch.no_grad():
             outputs = model(**inputs)
         scores = outputs.iou_scores.squeeze()
+        masks = outputs.pred_masks[0, 0, 0, 0, :3]
-        self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9935), atol=1e-4))
+        self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9935), atol=2e-4))
+        self.assertTrue(
+            torch.allclose(masks, torch.tensor([-21.5465, -23.1122, -22.3331]).to(torch_device), atol=2e-4)
+        )

     def test_inference_mask_generation_batched_points_batched_images(self):
         model = SamModel.from_pretrained("facebook/sam-vit-huge")
@@ -514,6 +523,7 @@
         with torch.no_grad():
             outputs = model(**inputs)
         scores = outputs.iou_scores.squeeze().cpu()
+        masks = outputs.pred_masks[0, 0, 0, 0, :3].cpu()

         EXPECTED_SCORES = torch.tensor(
             [
@@ -531,7 +541,9 @@
                 ],
             ]
         )
+        EXPECTED_MASKS = torch.tensor([-26.5424, -34.0901, -30.6406])
         self.assertTrue(torch.allclose(scores, EXPECTED_SCORES, atol=1e-3))
+        self.assertTrue(torch.allclose(masks, EXPECTED_MASKS, atol=1e-3))

     def test_inference_mask_generation_one_point_one_bb_zero(self):
         model = SamModel.from_pretrained("facebook/sam-vit-huge")
...
This diff is collapsed.
@@ -17,8 +17,14 @@ import unittest
 import numpy as np

-from transformers.testing_utils import require_torch, require_torchvision, require_vision
-from transformers.utils import is_torch_available, is_vision_available
+from transformers.testing_utils import (
+    is_pt_tf_cross_test,
+    require_tf,
+    require_torch,
+    require_torchvision,
+    require_vision,
+)
+from transformers.utils import is_tf_available, is_torch_available, is_vision_available

 if is_vision_available():
@@ -29,6 +35,9 @@ if is_vision_available():
 if is_torch_available():
     import torch

+if is_tf_available():
+    import tensorflow as tf
+

 @require_vision
 @require_torchvision
@@ -110,3 +119,158 @@ class SamProcessorTest(unittest.TestCase):
         dummy_masks = [[1, 0], [0, 1]]
         with self.assertRaises(ValueError):
             masks = processor.post_process_masks(dummy_masks, np.array(original_sizes), np.array(reshaped_input_size))
+
+
+@require_vision
+@require_tf
+class TFSamProcessorTest(unittest.TestCase):
+    def setUp(self):
+        self.tmpdirname = tempfile.mkdtemp()
+        image_processor = SamImageProcessor()
+        processor = SamProcessor(image_processor)
+        processor.save_pretrained(self.tmpdirname)
+
+    def get_image_processor(self, **kwargs):
+        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
+
+    def tearDown(self):
+        shutil.rmtree(self.tmpdirname)
+
+    def prepare_image_inputs(self):
+        """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
+        or a list of PyTorch tensors if one specifies torchify=True.
+        """
+        image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
+        image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
+        return image_inputs
+
+    def test_save_load_pretrained_additional_features(self):
+        processor = SamProcessor(image_processor=self.get_image_processor())
+        processor.save_pretrained(self.tmpdirname)
+
+        image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0)
+        processor = SamProcessor.from_pretrained(self.tmpdirname, do_normalize=False, padding_value=1.0)
+
+        self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string())
+        self.assertIsInstance(processor.image_processor, SamImageProcessor)
+
+    def test_image_processor(self):
+        image_processor = self.get_image_processor()
+        processor = SamProcessor(image_processor=image_processor)
+
+        image_input = self.prepare_image_inputs()
+
+        input_feat_extract = image_processor(image_input, return_tensors="np")
+        input_processor = processor(images=image_input, return_tensors="np")
+
+        input_feat_extract.pop("original_sizes")  # pop original_sizes as it is popped in the processor
+        input_feat_extract.pop("reshaped_input_sizes")  # pop reshaped_input_sizes as it is popped in the processor
+
+        for key in input_feat_extract.keys():
+            self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
+
+    @require_tf
+    def test_post_process_masks(self):
+        image_processor = self.get_image_processor()
+        processor = SamProcessor(image_processor=image_processor)
+        dummy_masks = [tf.ones((1, 3, 5, 5))]
+
+        original_sizes = [[1764, 2646]]
+
+        reshaped_input_size = [[683, 1024]]
+        masks = processor.post_process_masks(dummy_masks, original_sizes, reshaped_input_size, return_tensors="tf")
+        self.assertEqual(masks[0].shape, (1, 3, 1764, 2646))
+
+        masks = processor.post_process_masks(
+            dummy_masks,
+            tf.convert_to_tensor(original_sizes),
+            tf.convert_to_tensor(reshaped_input_size),
+            return_tensors="tf",
+        )
+        self.assertEqual(masks[0].shape, (1, 3, 1764, 2646))
+
+        # should also work with np
+        dummy_masks = [np.ones((1, 3, 5, 5))]
+        masks = processor.post_process_masks(
+            dummy_masks, np.array(original_sizes), np.array(reshaped_input_size), return_tensors="tf"
+        )
+        self.assertEqual(masks[0].shape, (1, 3, 1764, 2646))
+
+        dummy_masks = [[1, 0], [0, 1]]
+        with self.assertRaises(tf.errors.InvalidArgumentError):
+            masks = processor.post_process_masks(
+                dummy_masks, np.array(original_sizes), np.array(reshaped_input_size), return_tensors="tf"
+            )
+
+
+@require_vision
+@require_torchvision
+class SamProcessorEquivalenceTest(unittest.TestCase):
+    def setUp(self):
+        self.tmpdirname = tempfile.mkdtemp()
+        image_processor = SamImageProcessor()
+        processor = SamProcessor(image_processor)
+        processor.save_pretrained(self.tmpdirname)
+
+    def get_image_processor(self, **kwargs):
+        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
+
+    def tearDown(self):
+        shutil.rmtree(self.tmpdirname)
+
+    def prepare_image_inputs(self):
+        """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
+        or a list of PyTorch tensors if one specifies torchify=True.
+        """
+        image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
+        image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
+        return image_inputs
+
+    @is_pt_tf_cross_test
+    def test_post_process_masks_equivalence(self):
+        image_processor = self.get_image_processor()
+        processor = SamProcessor(image_processor=image_processor)
+        dummy_masks = np.random.randint(0, 2, size=(1, 3, 5, 5)).astype(np.float32)
+        tf_dummy_masks = [tf.convert_to_tensor(dummy_masks)]
+        pt_dummy_masks = [torch.tensor(dummy_masks)]
+
+        original_sizes = [[1764, 2646]]
+
+        reshaped_input_size = [[683, 1024]]
+
+        tf_masks = processor.post_process_masks(
+            tf_dummy_masks, original_sizes, reshaped_input_size, return_tensors="tf"
+        )
+        pt_masks = processor.post_process_masks(
+            pt_dummy_masks, original_sizes, reshaped_input_size, return_tensors="pt"
+        )
+
+        self.assertTrue(np.all(tf_masks[0].numpy() == pt_masks[0].numpy()))
+
+    @is_pt_tf_cross_test
+    def test_image_processor_equivalence(self):
+        image_processor = self.get_image_processor()
+        processor = SamProcessor(image_processor=image_processor)
+
+        image_input = self.prepare_image_inputs()
+
+        pt_input_feat_extract = image_processor(image_input, return_tensors="pt")["pixel_values"].numpy()
+        pt_input_processor = processor(images=image_input, return_tensors="pt")["pixel_values"].numpy()
+
+        tf_input_feat_extract = image_processor(image_input, return_tensors="tf")["pixel_values"].numpy()
+        tf_input_processor = processor(images=image_input, return_tensors="tf")["pixel_values"].numpy()
+
+        self.assertTrue(np.allclose(pt_input_feat_extract, pt_input_processor))
+        self.assertTrue(np.allclose(pt_input_feat_extract, tf_input_feat_extract))
+        self.assertTrue(np.allclose(pt_input_feat_extract, tf_input_processor))