Unverified Commit 1c460a52 authored by Matt's avatar Matt Committed by GitHub
Browse files

TF port of the Segment Anything Model (SAM) (#22970)



* First commit

* Add auto-translation with GPT-4

* make fixup

* Add a functional layernorm for TF

* Add all the auxiliary imports etc.

* Add the extra processor and tests

* rebase to main

* Add all the needed fixes to the GPT code

* make fixup

* Make convolutions channels-last so they run on CPU

* make fixup

* Fix final issues

* Fix other models affected by test change

* Clarify comment on the sparse_prompt_embeddings check

* Refactor functional_layernorm, use shape_list in place of .shape in some places

* Remove deprecated torch-alike code

* Update tests/models/sam/test_modeling_tf_sam.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/sam/test_modeling_tf_sam.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Refactor processor with common methods and separated private methods

* make fixup

* Quietly delete the file that didn't do anything (sorry Sylvain)

* Refactor the processor tests into one file

* make fixup

* Clean up some unnecessary indirection

* Fix TF mask postprocessing

* Add more processor equivalence tests

* Refactor generate_crop_boxes to use framework-neutral np code

* Make the serving output correctly conditional

* Fix error message line length

* Use dict keys rather than indices internally in both TF and PT SAM call/forward

* Return dicts internally in the call/forward methods

* Revert changes to common tests and just override check_pt_tf_outputs

* Revert changes to other model tests

* Clarify comments for functional layernorm

* Add missing transpose from PT code

* Removed unused copied from in PT code

* Remove overrides for tests that don't exist in TF

* Fix transpose and update tests for PT and TF to check pred_masks

* Add training flag

* Update tests to use TF checkpoints

* Update index.mdx

* Add missing cross-test decorator

* Remove optional extra asterisks

* Revert return_dict changes in PT code

* Update src/transformers/models/sam/modeling_tf_sam.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Remove None return annotations on init methods

* Update tests/models/sam/test_processor_sam.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Fix input_boxes shapes

* make fixup

---------
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 8aa8513f
......@@ -399,7 +399,7 @@ Flax), PyTorch, and/or TensorFlow.
| RoCBert | ✅ | ❌ | ✅ | ❌ | ❌ |
| RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ |
| RWKV | ❌ | ❌ | ✅ | ❌ | ❌ |
| SAM | ❌ | ❌ | ✅ | | ❌ |
| SAM | ❌ | ❌ | ✅ | | ❌ |
| SegFormer | ❌ | ❌ | ✅ | ✅ | ❌ |
| SEW | ❌ | ❌ | ✅ | ❌ | ❌ |
| SEW-D | ❌ | ❌ | ✅ | ❌ | ❌ |
......
......@@ -99,3 +99,9 @@ Resources:
[[autodoc]] SamModel
- forward
## TFSamModel
[[autodoc]] TFSamModel
- call
\ No newline at end of file
......@@ -3406,6 +3406,13 @@ else:
"TFRoFormerPreTrainedModel",
]
)
_import_structure["models.sam"].extend(
[
"TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFSamModel",
"TFSamPreTrainedModel",
]
)
_import_structure["models.segformer"].extend(
[
"TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
......@@ -6657,6 +6664,11 @@ if TYPE_CHECKING:
TFRoFormerModel,
TFRoFormerPreTrainedModel,
)
from .models.sam import (
TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSamModel,
TFSamPreTrainedModel,
)
from .models.segformer import (
TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSegformerDecodeHead,
......
......@@ -76,6 +76,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
("roberta", "TFRobertaModel"),
("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
("roformer", "TFRoFormerModel"),
("sam", "TFSamModel"),
("segformer", "TFSegformerModel"),
("speech_to_text", "TFSpeech2TextModel"),
("swin", "TFSwinModel"),
......@@ -426,6 +427,11 @@ TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
("mobilebert", "TFMobileBertForNextSentencePrediction"),
]
)
TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
[
("sam", "TFSamModel"),
]
)
TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES)
TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
......@@ -476,6 +482,14 @@ TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES
)
class TFAutoModelForMaskGeneration(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_MASK_GENERATION_MAPPING
class TFAutoModel(_BaseAutoModelClass):
_model_mapping = TF_MODEL_MAPPING
......
......@@ -13,7 +13,13 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_tf_available,
is_torch_available,
is_vision_available,
)
_import_structure = {
......@@ -39,6 +45,17 @@ else:
"SamModel",
"SamPreTrainedModel",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_sam"] = [
"TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFSamModel",
"TFSamPreTrainedModel",
]
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
......@@ -66,6 +83,14 @@ if TYPE_CHECKING:
else:
from .modeling_sam import SAM_PRETRAINED_MODEL_ARCHIVE_LIST, SamModel, SamPreTrainedModel
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_sam import TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST, TFSamModel, TFSamPreTrainedModel
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
......
......@@ -34,7 +34,14 @@ from ...image_utils import (
to_numpy_array,
valid_images,
)
from ...utils import TensorType, is_torch_available, is_torchvision_available, logging, requires_backends
from ...utils import (
TensorType,
is_tf_available,
is_torch_available,
is_torchvision_available,
logging,
requires_backends,
)
if is_torch_available():
......@@ -44,6 +51,12 @@ if is_torch_available():
if is_torchvision_available():
from torchvision.ops.boxes import batched_nms
if is_tf_available():
import tensorflow as tf
from tensorflow.experimental import numpy as tnp
from ...tf_utils import flatten, shape_list
logger = logging.get_logger(__name__)
......@@ -372,6 +385,61 @@ class SamImageProcessor(BaseImageProcessor):
return encoded_outputs
def post_process_masks(
self,
masks,
original_sizes,
reshaped_input_sizes,
mask_threshold=0.0,
binarize=True,
pad_size=None,
return_tensors="pt",
):
"""
Remove padding and upscale masks to the original image size.
Args:
masks (`Union[List[torch.Tensor], List[np.ndarray], List[tf.Tensor]]`):
Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
original_sizes (`Union[torch.Tensor, tf.Tensor, List[Tuple[int,int]]]`):
The original sizes of each image before it was resized to the model's expected input shape, in (height,
width) format.
reshaped_input_sizes (`Union[torch.Tensor, tf.Tensor, List[Tuple[int,int]]]`):
The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
mask_threshold (`float`, *optional*, defaults to 0.0):
The threshold to use for binarizing the masks.
binarize (`bool`, *optional*, defaults to `True`):
Whether to binarize the masks.
pad_size (`int`, *optional*, defaults to `self.pad_size`):
The target size the images were padded to before being passed to the model. If None, the target size is
assumed to be the processor's `pad_size`.
return_tensors (`str`, *optional*, defaults to `"pt"`):
If `"pt"`, return PyTorch tensors. If `"tf"`, return TensorFlow tensors.
Returns:
(`Union[torch.Tensor, tf.Tensor]`): Batched masks in batch_size, num_channels, height, width) format, where
(height, width) is given by original_size.
"""
if return_tensors == "pt":
return self._post_process_masks_pt(
masks=masks,
original_sizes=original_sizes,
reshaped_input_sizes=reshaped_input_sizes,
mask_threshold=mask_threshold,
binarize=binarize,
pad_size=pad_size,
)
elif return_tensors == "tf":
return self._post_process_masks_tf(
masks=masks,
original_sizes=original_sizes,
reshaped_input_sizes=reshaped_input_sizes,
mask_threshold=mask_threshold,
binarize=binarize,
pad_size=pad_size,
)
else:
raise ValueError("return_tensors must be either 'pt' or 'tf'")
def _post_process_masks_pt(
self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None
):
"""
......@@ -418,21 +486,70 @@ class SamImageProcessor(BaseImageProcessor):
return output_masks
def post_process_for_mask_generation(self, all_masks, all_scores, all_boxes, crops_nms_thresh):
def _post_process_masks_tf(
self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None
):
"""
Remove padding and upscale masks to the original image size.
Args:
masks (`tf.Tensor`):
Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
original_sizes (`tf.Tensor`):
The original size of the images before resizing for input to the model, in (height, width) format.
reshaped_input_sizes (`tf.Tensor`):
The size of the image input to the model, in (height, width) format. Used to remove padding.
mask_threshold (`float`, *optional*, defaults to 0.0):
The threshold to use for binarizing the masks.
binarize (`bool`, *optional*, defaults to `True`):
Whether to binarize the masks.
pad_size (`int`, *optional*, defaults to `self.pad_size`):
The target size the images were padded to before being passed to the model. If None, the target size is
assumed to be the processor's `pad_size`.
Returns:
(`tf.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) is
given by original_size.
"""
requires_backends(self, ["tf"])
pad_size = self.pad_size if pad_size is None else pad_size
target_image_size = (pad_size["height"], pad_size["width"])
output_masks = []
for i, original_size in enumerate(original_sizes):
# tf.image expects NHWC, we transpose the NCHW inputs for it
mask = tf.transpose(masks[i], perm=[0, 2, 3, 1])
interpolated_mask = tf.image.resize(mask, target_image_size, method="bilinear")
interpolated_mask = interpolated_mask[:, : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1], :]
interpolated_mask = tf.image.resize(interpolated_mask, original_size, method="bilinear")
if binarize:
interpolated_mask = interpolated_mask > mask_threshold
# And then we transpose them back at the end
output_masks.append(tf.transpose(interpolated_mask, perm=[0, 3, 1, 2]))
return output_masks
def post_process_for_mask_generation(
self, all_masks, all_scores, all_boxes, crops_nms_thresh, return_tensors="pt"
):
"""
Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks.
Args:
all_masks (`List[torch.Tensor]`):
all_masks (`Union[List[torch.Tensor], List[tf.Tensor]]`):
List of all predicted segmentation masks
all_scores (`List[torch.Tensor]`):
all_scores (`Union[List[torch.Tensor], List[tf.Tensor]]`):
List of all predicted iou scores
all_boxes (`List[torch.Tensor]`):
all_boxes (`Union[List[torch.Tensor], List[tf.Tensor]]`):
List of all bounding boxes of the predicted masks
crops_nms_thresh (`float`):
Threshold for NMS (Non Maximum Suppression) algorithm.
return_tensors (`str`, *optional*, defaults to `pt`):
If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
"""
return _postprocess_for_mg(all_masks, all_scores, all_boxes, crops_nms_thresh)
if return_tensors == "pt":
return _postprocess_for_mg(all_masks, all_scores, all_boxes, crops_nms_thresh)
elif return_tensors == "tf":
return _postprocess_for_mg_tf(all_masks, all_scores, all_boxes, crops_nms_thresh)
def generate_crop_boxes(
self,
......@@ -443,6 +560,7 @@ class SamImageProcessor(BaseImageProcessor):
points_per_crop: Optional[int] = 32,
crop_n_points_downscale_factor: Optional[List[int]] = 1,
device: Optional["torch.device"] = None,
return_tensors: str = "pt",
):
"""
Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
......@@ -464,10 +582,35 @@ class SamImageProcessor(BaseImageProcessor):
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
device (`torch.device`, *optional*, defaults to None):
Device to use for the computation. If None, cpu will be used.
return_tensors (`str`, *optional*, defaults to `pt`):
If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
"""
return _generate_crop_boxes(
image, target_size, crop_n_layers, overlap_ratio, points_per_crop, crop_n_points_downscale_factor, device
crop_boxes, points_per_crop, cropped_images, input_labels = _generate_crop_boxes(
image,
target_size,
crop_n_layers,
overlap_ratio,
points_per_crop,
crop_n_points_downscale_factor,
)
if return_tensors == "pt":
if device is None:
device = torch.device("cpu")
crop_boxes = torch.tensor(crop_boxes, device=device)
points_per_crop = torch.tensor(points_per_crop, device=device)
# cropped_images stays as np
input_labels = torch.tensor(input_labels, device=device)
elif return_tensors == "tf":
if device is not None:
raise ValueError("device is not a supported argument when return_tensors is tf!")
crop_boxes = tf.convert_to_tensor(crop_boxes)
points_per_crop = tf.convert_to_tensor(points_per_crop)
# cropped_images stays as np
input_labels = tf.convert_to_tensor(input_labels)
else:
raise ValueError("return_tensors must be either 'pt' or 'tf'.")
return crop_boxes, points_per_crop, cropped_images, input_labels
def filter_masks(
self,
......@@ -479,6 +622,67 @@ class SamImageProcessor(BaseImageProcessor):
stability_score_thresh=0.95,
mask_threshold=0,
stability_score_offset=1,
return_tensors="pt",
):
"""
Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being
that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability
score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
bounding boxes and pad the predicted masks if necessary.
Args:
masks (`Union[torch.Tensor, tf.Tensor]`):
Input masks.
iou_scores (`Union[torch.Tensor, tf.Tensor]`):
List of IoU scores.
original_size (`Tuple[int,int]`):
Size of the orginal image.
cropped_box_image (`np.array`):
The cropped image.
pred_iou_thresh (`float`, *optional*, defaults to 0.88):
The threshold for the iou scores.
stability_score_thresh (`float`, *optional*, defaults to 0.95):
The threshold for the stability score.
mask_threshold (`float`, *optional*, defaults to 0):
The threshold for the predicted masks.
stability_score_offset (`float`, *optional*, defaults to 1):
The offset for the stability score used in the `_compute_stability_score` method.
return_tensors (`str`, *optional*, defaults to `pt`):
If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
"""
if return_tensors == "pt":
return self._filter_masks_pt(
masks=masks,
iou_scores=iou_scores,
original_size=original_size,
cropped_box_image=cropped_box_image,
pred_iou_thresh=pred_iou_thresh,
stability_score_thresh=stability_score_thresh,
mask_threshold=mask_threshold,
stability_score_offset=stability_score_offset,
)
elif return_tensors == "tf":
return self._filter_masks_tf(
masks=masks,
iou_scores=iou_scores,
original_size=original_size,
cropped_box_image=cropped_box_image,
pred_iou_thresh=pred_iou_thresh,
stability_score_thresh=stability_score_thresh,
mask_threshold=mask_threshold,
stability_score_offset=stability_score_offset,
)
def _filter_masks_pt(
self,
masks,
iou_scores,
original_size,
cropped_box_image,
pred_iou_thresh=0.88,
stability_score_thresh=0.95,
mask_threshold=0,
stability_score_offset=1,
):
"""
Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being
......@@ -525,7 +729,7 @@ class SamImageProcessor(BaseImageProcessor):
# compute stability score
if stability_score_thresh > 0.0:
stability_scores = _compute_stability_score(masks, mask_threshold, stability_score_offset)
stability_scores = _compute_stability_score_pt(masks, mask_threshold, stability_score_offset)
keep_mask = keep_mask & (stability_scores > stability_score_thresh)
scores = iou_scores[keep_mask]
......@@ -549,8 +753,85 @@ class SamImageProcessor(BaseImageProcessor):
return masks, scores, converted_boxes
def _filter_masks_tf(
self,
masks,
iou_scores,
original_size,
cropped_box_image,
pred_iou_thresh=0.88,
stability_score_thresh=0.95,
mask_threshold=0,
stability_score_offset=1,
):
"""
Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being
that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability
score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
bounding boxes and pad the predicted masks if necessary.
def _compute_stability_score(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int):
Args:
masks (`tf.Tensor`):
Input masks.
iou_scores (`tf.Tensor`):
List of IoU scores.
original_size (`Tuple[int,int]`):
Size of the orginal image.
cropped_box_image (`np.array`):
The cropped image.
pred_iou_thresh (`float`, *optional*, defaults to 0.88):
The threshold for the iou scores.
stability_score_thresh (`float`, *optional*, defaults to 0.95):
The threshold for the stability score.
mask_threshold (`float`, *optional*, defaults to 0):
The threshold for the predicted masks.
stability_score_offset (`float`, *optional*, defaults to 1):
The offset for the stability score used in the `_compute_stability_score` method.
"""
requires_backends(self, ["tf"])
original_height, original_width = original_size
iou_scores = tf.reshape(iou_scores, [iou_scores.shape[0] * iou_scores.shape[1], iou_scores.shape[2:]])
masks = tf.reshape(masks, [masks.shape[0] * masks.shape[1], masks.shape[2:]])
if masks.shape[0] != iou_scores.shape[0]:
raise ValueError("masks and iou_scores must have the same batch size.")
batch_size = masks.shape[0]
keep_mask = tf.ones(batch_size, dtype=tf.bool)
if pred_iou_thresh > 0.0:
keep_mask = keep_mask & (iou_scores > pred_iou_thresh)
# compute stability score
if stability_score_thresh > 0.0:
stability_scores = _compute_stability_score_tf(masks, mask_threshold, stability_score_offset)
keep_mask = keep_mask & (stability_scores > stability_score_thresh)
scores = iou_scores[keep_mask]
masks = masks[keep_mask]
# binarize masks
masks = masks > mask_threshold
converted_boxes = _batched_mask_to_box_tf(masks)
keep_mask = ~_is_box_near_crop_edge_tf(
converted_boxes, cropped_box_image, [0, 0, original_width, original_height]
)
scores = scores[keep_mask]
masks = masks[keep_mask]
converted_boxes = converted_boxes[keep_mask]
masks = _pad_masks_tf(masks, cropped_box_image, original_height, original_width)
# conversion to rle is necessary to run non-maximum suppresion
masks = _mask_to_rle_tf(masks)
return masks, scores, converted_boxes
def _compute_stability_score_pt(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int):
# One mask is always contained inside the other.
# Save memory by preventing unnecesary cast to torch.int64
intersections = (
......@@ -561,6 +842,17 @@ def _compute_stability_score(masks: "torch.Tensor", mask_threshold: float, stabi
return stability_scores
def _compute_stability_score_tf(masks: "tf.Tensor", mask_threshold: float, stability_score_offset: int):
# Torch does Py3-style division but TF does floor division with ints. We cast to float32 in TF to make sure
# we get the right division results.
intersections = tf.count_nonzero(
masks > (mask_threshold + stability_score_offset), axis=[-1, -2], dtype=tf.float32
)
unions = tf.count_nonzero(masks > (mask_threshold - stability_score_offset), axis=[-1, -2], dtype=tf.float32)
stability_scores = intersections / unions
return stability_scores
def _build_point_grid(n_per_side: int) -> np.ndarray:
"""Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
offset = 1 / (2 * n_per_side)
......@@ -606,7 +898,6 @@ def _generate_crop_boxes(
overlap_ratio: float = 512 / 1500,
points_per_crop: Optional[int] = 32,
crop_n_points_downscale_factor: Optional[List[int]] = 1,
device: Optional["torch.device"] = None,
) -> Tuple[List[List[int]], List[int]]:
"""
Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
......@@ -626,11 +917,7 @@ def _generate_crop_boxes(
Number of points to sample per crop.
crop_n_points_downscale_factor (`int`, *optional*):
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
device (`torch.device`, *optional*):
Device to run the crop generation on. Defaults to CPU.
"""
if device is None:
device = torch.device("cpu")
if isinstance(image, list):
raise ValueError("Only one image is allowed for crop generation.")
......@@ -648,12 +935,11 @@ def _generate_crop_boxes(
crop_boxes, image, points_grid, layer_idxs, target_size, original_size
)
crop_boxes = torch.tensor(crop_boxes, dtype=torch.float32, device=device)
point_grid_per_crop = np.array([point_grid_per_crop])
points_per_crop = torch.tensor(point_grid_per_crop, device=device)
points_per_crop = points_per_crop.permute(0, 2, 1, 3)
crop_boxes = crop_boxes.astype(np.float32)
points_per_crop = np.array([point_grid_per_crop])
points_per_crop = np.transpose(points_per_crop, axes=(0, 2, 1, 3))
input_labels = torch.ones_like(points_per_crop[:, :, :, 0], dtype=torch.long, device=device)
input_labels = np.ones_like(points_per_crop[:, :, :, 0], dtype=np.int64)
return crop_boxes, points_per_crop, cropped_images, input_labels
......@@ -730,6 +1016,16 @@ def _pad_masks(masks, crop_box: List[int], orig_height: int, orig_width: int):
return torch.nn.functional.pad(masks, pad, value=0)
def _pad_masks_tf(masks, crop_box: List[int], orig_height: int, orig_width: int):
left, top, right, bottom = crop_box
if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
return masks
# Coordinate transform masks
pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
pad = (left, pad_x - left, top, pad_y - top)
return tf.pad(masks, pad, constant_values=0)
def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0):
"""Filter masks at the edge of a crop, but not at the edge of the original image."""
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
......@@ -748,6 +1044,24 @@ def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0):
return torch.any(near_crop_edge, dim=1)
def _is_box_near_crop_edge_tf(boxes, crop_box, orig_box, atol=20.0):
"""Filter masks at the edge of a crop, but not at the edge of the original image."""
crop_box_tf = tf.convert_to_tensor(crop_box, dtype=tf.float32)
orig_box_tf = tf.convert_to_tensor(orig_box, dtype=tf.float32)
left, top, _, _ = crop_box
offset = tf.convert_to_tensor([[left, top, left, top]])
# Check if boxes has a channel dimension
if len(boxes.shape) == 3:
offset = tf.expand_dims(offset, 1)
boxes = tf.cast(boxes + offset, tf.float32)
near_crop_edge = tnp.isclose(boxes, crop_box_tf[None, :], atol=atol, rtol=0)
near_image_edge = tnp.isclose(boxes, orig_box_tf[None, :], atol=atol, rtol=0)
near_crop_edge = tf.math.logical_and(near_crop_edge, ~near_image_edge)
return tf.reduce_any(near_crop_edge, axis=1)
def _batched_mask_to_box(masks: "torch.Tensor"):
"""
Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
......@@ -797,6 +1111,54 @@ def _batched_mask_to_box(masks: "torch.Tensor"):
return out
def _batched_mask_to_box_tf(masks: "tf.Tensor"):
"""
Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
corresponds the following required indices:
- LEFT: left hand side of the bounding box
- TOP: top of the bounding box
- RIGHT: right of the bounding box
- BOTTOM: bottom of the bounding box
Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape
is channel_1 x channel_2 x ... x 4.
Args:
- masks (`tf.Tensor` of shape `(batch, nb_mask, height, width)`)
"""
if tf.size(masks) == 0:
return tf.zeros([*masks.shape[:-2], 4])
# Normalize shape to Cxheightxwidth
shape = shape_list(masks)
height, width = shape[-2:]
# Get top and bottom edges
in_height = tf.reduce_max(masks, axis=-1)
in_height_coords = in_height * tf.range(height)[None, :]
bottom_edges = tf.reduce_max(in_height_coords, axis=-1)
in_height_coords = in_height_coords + height * (~in_height)
top_edges = tf.reduce_min(in_height_coords, axis=-1)
# Get left and right edges
in_width, _ = tf.reduce_max(masks, axis=-2)
in_width_coords = in_width * tf.range(width)[None, :]
right_edges, _ = tf.reduce_max(in_width_coords, axis=-1)
in_width_coords = in_width_coords + width * (~in_width)
left_edges, _ = tf.reduce_min(in_width_coords, axis=-1)
# If the mask is empty the right edge will be to the left of the left edge.
# Replace these boxes with [0, 0, 0, 0]
empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
out = tf.stack([left_edges, top_edges, right_edges, bottom_edges], axis=-1)
out = out * tf.expand_dims(~empty_filter, -1)
# Return to original shape
out = tf.reshape(out, *shape[:-2], 4)
return out
def _mask_to_rle_pytorch(input_mask: "torch.Tensor"):
"""
Encodes masks the run-length encoding (RLE), in the format expected by pycoco tools.
......@@ -820,6 +1182,29 @@ def _mask_to_rle_pytorch(input_mask: "torch.Tensor"):
return out
def _mask_to_rle_tf(input_mask: "tf.Tensor"):
"""
Encodes masks the run-length encoding (RLE), in the format expected by pycoco tools.
"""
# Put in fortran order and flatten height and width
batch_size, height, width = input_mask.shape
input_mask = flatten(tf.transpose(input_mask, perm=(0, 2, 1)), 1)
# Compute change indices
diff = input_mask[:, 1:] ^ input_mask[:, :-1]
change_indices = tf.where(diff)
# Encode run length
out = []
for i in range(batch_size):
cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
counts = [] if input_mask[i, 0] == 0 else [0]
counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1]]
out.append({"size": [height, width], "counts": counts})
return out
def _rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
"""Compute a binary mask from an uncompressed RLE."""
height, width = rle["size"]
......@@ -836,7 +1221,7 @@ def _rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
def _postprocess_for_mg(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
"""
Perform NMS (Non Maxium Suppression) on the outputs.
Perform NMS (Non Maximum Suppression) on the outputs.
Args:
rle_masks (`torch.Tensor`):
......@@ -861,3 +1246,32 @@ def _postprocess_for_mg(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=
masks = [_rle_to_mask(rle) for rle in rle_masks]
return masks, iou_scores, rle_masks, mask_boxes
def _postprocess_for_mg_tf(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
"""
Perform NMS (Non Maximum Suppression) on the outputs.
Args:
rle_masks (`tf.Tensor`):
binary masks in the RLE format
iou_scores (`tf.Tensor` of shape (nb_masks, 1)):
iou_scores predicted by the model
mask_boxes (`tf.Tensor`):
The bounding boxes corresponding to segmentation masks
amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
NMS threshold.
"""
keep_by_nms = tf.image.combined_non_max_suppression(
boxes=mask_boxes.float(),
scores=iou_scores,
idxs=torch.zeros(mask_boxes.shape[0]),
iou_threshold=amg_crops_nms_thresh,
)
iou_scores = iou_scores[keep_by_nms]
rle_masks = [rle_masks[i] for i in keep_by_nms]
mask_boxes = mask_boxes[keep_by_nms]
masks = [_rle_to_mask(rle) for rle in rle_masks]
return masks, iou_scores, rle_masks, mask_boxes
......@@ -111,7 +111,6 @@ class SamImageSegmentationOutput(ModelOutput):
mask_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
# Copied from src.models.modeling_vit_mae.ViTMAEPatchEmbeddings with ViTMAEPatchEmbeddings->SamVisionEmbeddings,x->embeddings
class SamPatchEmbeddings(nn.Module):
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
......@@ -198,7 +197,7 @@ class SamAttention(nn.Module):
values.
"""
def __init__(self, config, downsample_rate=None) -> None:
def __init__(self, config, downsample_rate=None):
super().__init__()
self.hidden_size = config.hidden_size
......@@ -252,7 +251,7 @@ class SamAttention(nn.Module):
class SamTwoWayAttentionBlock(nn.Module):
def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False) -> None:
def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
"""
A transformer block with four layers:
(1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
......@@ -476,7 +475,7 @@ class SamMaskDecoder(nn.Module):
the embeddings of the mask inputs
multimask_output (bool):
Whether to return multiple masks or a single mask.
output_attentions (bool, **optional**):
output_attentions (bool, *optional*):
Whether or not to return the attentions tensors of all attention layers.
"""
batch_size, num_channels, height, width = image_embeddings.shape
......@@ -668,11 +667,11 @@ class SamPromptEncoder(nn.Module):
Embeds different types of prompts, returning both sparse and dense embeddings.
Args:
points (`torch.Tensor`, **optionnal**):
points (`torch.Tensor`, *optional*):
point coordinates and labels to embed.
boxes (`torch.Tensor`, **optionnal**):
boxes (`torch.Tensor`, *optional*):
boxes to embed
masks (`torch.Tensor`, **optionnal**):
masks (`torch.Tensor`, *optional*):
masks to embed
"""
sparse_embeddings = None
......@@ -707,7 +706,7 @@ class SamPromptEncoder(nn.Module):
class SamVisionAttention(nn.Module):
"""Multi-head Attention block with relative position embeddings."""
def __init__(self, config, window_size) -> None:
def __init__(self, config, window_size):
super().__init__()
input_size = (
(config.image_size // config.patch_size, config.image_size // config.patch_size)
......@@ -845,7 +844,7 @@ class SamVisionAttention(nn.Module):
class SamVisionLayer(nn.Module):
def __init__(self, config, window_size) -> None:
def __init__(self, config, window_size):
super().__init__()
self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attn = SamVisionAttention(config, window_size)
......@@ -1166,7 +1165,7 @@ SAM_INPUTS_DOCSTRING = r"""
class SamModel(SamPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"]
def __init__(self, config) -> None:
def __init__(self, config):
super().__init__(config)
self.shared_image_embedding = SamPositionalEmbedding(config.vision_config)
......@@ -1334,7 +1333,6 @@ class SamModel(SamPreTrainedModel):
image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
vision_attentions = None
mask_decoder_attentions = None
vision_hidden_states = None
if pixel_values is not None:
......@@ -1359,7 +1357,8 @@ class SamModel(SamPreTrainedModel):
"The batch size of the image embeddings and the input points must be the same. ",
"Got {} and {} respectively.".format(image_embeddings.shape[0], input_points.shape[0]),
" if you want to pass multiple points for the same image, make sure that you passed ",
" input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
" input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ",
" input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
)
sparse_embeddings, dense_embeddings = self.prompt_encoder(
......
# coding=utf-8
# Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
TensorFlow SAM model. This file was mostly generated by auto-translation from the PyTorch original. In the event of a
discrepancy, the original file should be regarded as the 'reference' version.
"""
import collections
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from ...activations_tf import ACT2FN
from ...modeling_tf_outputs import TFBaseModelOutput
from ...modeling_tf_utils import TFPreTrainedModel, shape_list, unpack_inputs
from ...tf_utils import flatten, functional_layernorm
from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "SamConfig"
_CHECKPOINT_FOR_DOC = "facebook/sam-vit-huge"
TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/sam-vit-huge",
"facebook/sam-vit-large",
"facebook/sam-vit-base",
# See all SAM models at https://huggingface.co/models?filter=sam
]
@dataclass
class TFSamVisionEncoderOutput(ModelOutput):
"""
Base class for sam vision model's outputs that also contains image embeddings obtained by applying the projection
layer to the pooler_output.
Args:
image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
The image embeddings obtained by applying the projection layer to the pooler_output.
last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
image_embeds: Optional[tf.Tensor] = None
last_hidden_state: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFSamImageSegmentationOutput(ModelOutput):
"""
Base class for Segment-Anything model's output
Args:
iou_scores (`tf.Tensor` of shape `(batch_size, num_masks)`):
The iou scores of the predicted masks.
pred_masks (`tf.Tensor` of shape `(batch_size, num_masks, height, width)`):
The predicted low resolutions masks. Needs to be post-processed by the processor
vision_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs.
vision_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
mask_decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
iou_scores: tf.Tensor = None
pred_masks: tf.Tensor = None
vision_hidden_states: Optional[Tuple[tf.Tensor]] = None
vision_attentions: Optional[Tuple[tf.Tensor]] = None
mask_decoder_attentions: Optional[Tuple[tf.Tensor]] = None
class TFSamPatchEmbeddings(tf.keras.layers.Layer):
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
"""
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.hidden_size
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches
self.projection = tf.keras.layers.Conv2D(
hidden_size, kernel_size=patch_size, strides=patch_size, name="projection"
)
def call(self, pixel_values):
batch_size, num_channels, height, width = shape_list(pixel_values)
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
)
embeddings = self.projection(tf.transpose(pixel_values, perm=[0, 2, 3, 1]))
return embeddings
class TFSamMLPBlock(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.lin1 = tf.keras.layers.Dense(config.mlp_dim, name="lin1")
self.lin2 = tf.keras.layers.Dense(config.hidden_size, name="lin2")
self.act = ACT2FN[config.hidden_act]
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.lin1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.lin2(hidden_states)
return hidden_states
class TFSamLayerNorm(tf.keras.layers.Layer):
r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
"""
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", **kwargs):
super().__init__(**kwargs)
self.eps = eps
self.data_format = data_format
self.normalized_shape = normalized_shape
if self.data_format not in ["channels_last", "channels_first"]:
raise NotImplementedError(f"Unsupported data format: {self.data_format}")
def build(self, input_shape):
self.weight = self.add_weight(shape=self.normalized_shape, initializer="ones", name="weight")
self.bias = self.add_weight(shape=self.normalized_shape, initializer="zeros", name="bias")
super().build(input_shape)
def call(self, x: tf.Tensor) -> tf.Tensor:
if self.data_format == "channels_last":
x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=-1)
elif self.data_format == "channels_first":
x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=1)
return x
class TFSamAttention(tf.keras.layers.Layer):
"""
SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
values.
"""
def __init__(self, config, downsample_rate=None, **kwargs):
super().__init__(**kwargs)
self.hidden_size = config.hidden_size
downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
self.internal_dim = config.hidden_size // downsample_rate
self.num_attention_heads = config.num_attention_heads
if self.internal_dim % config.num_attention_heads != 0:
raise ValueError("num_attention_heads must divide hidden_size.")
self.q_proj = tf.keras.layers.Dense(self.internal_dim, name="q_proj")
self.k_proj = tf.keras.layers.Dense(self.internal_dim, name="k_proj")
self.v_proj = tf.keras.layers.Dense(self.internal_dim, name="v_proj")
self.out_proj = tf.keras.layers.Dense(self.hidden_size, name="out_proj")
def _separate_heads(self, hidden_states: tf.Tensor, num_attention_heads: int) -> tf.Tensor:
batch, point_batch_size, n_tokens, channel = shape_list(hidden_states)
c_per_head = channel // num_attention_heads
hidden_states = tf.reshape(
hidden_states, (batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
)
return tf.transpose(hidden_states, perm=[0, 2, 1, 3])
def _recombine_heads(self, hidden_states: tf.Tensor, point_batch_size: int) -> tf.Tensor:
batch, n_heads, n_tokens, c_per_head = shape_list(hidden_states)
hidden_states = tf.transpose(hidden_states, perm=[0, 2, 1, 3])
return tf.reshape(
hidden_states, (batch // max(1, point_batch_size), point_batch_size, n_tokens, n_heads * c_per_head)
)
def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor) -> tf.Tensor:
# Input projections
query = self.q_proj(query)
key = self.k_proj(key)
value = self.v_proj(value)
point_batch_size = shape_list(query)[1]
# Separate into heads
query = self._separate_heads(query, self.num_attention_heads)
key = self._separate_heads(key, self.num_attention_heads)
value = self._separate_heads(value, self.num_attention_heads)
# SamAttention
_, _, _, c_per_head = shape_list(query)
attn = tf.matmul(
query, tf.transpose(key, perm=[0, 1, 3, 2])
) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens
attn = attn / tf.math.sqrt(float(c_per_head))
attn = tf.nn.softmax(attn, axis=-1)
# Get output
out = tf.matmul(attn, value)
out = self._recombine_heads(out, point_batch_size)
out = self.out_proj(out)
return out
class TFSamTwoWayAttentionBlock(tf.keras.layers.Layer):
def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, **kwargs):
"""
A transformer block with four layers:
(1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
sparse inputs (4) cross attention of dense inputs -> sparse inputs
Arguments:
config (`SamMaskDecoderConfig`):
The configuration file used to instantiate the block
attention_downsample_rate (*optionalk*, int, defaults to 2):
The downsample ratio of the block used to reduce the inner dim of the attention.
skip_first_layer_pe (*optional*, bool, defaults to `False`):
Whether or not to skip the addition of the query_point_embedding on the first layer.
"""
super().__init__(**kwargs)
self.hidden_size = config.hidden_size
self.layer_norm_eps = config.layer_norm_eps
self.self_attn = TFSamAttention(config, downsample_rate=1, name="self_attn")
self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm1")
self.cross_attn_token_to_image = TFSamAttention(
config, downsample_rate=attention_downsample_rate, name="cross_attn_token_to_image"
)
self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm2")
self.mlp = TFSamMLPBlock(config, name="mlp")
self.layer_norm3 = tf.keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm3")
self.layer_norm4 = tf.keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm4")
self.cross_attn_image_to_token = TFSamAttention(
config, downsample_rate=attention_downsample_rate, name="cross_attn_image_to_token"
)
self.skip_first_layer_pe = skip_first_layer_pe
def call(
self,
queries: tf.Tensor,
keys: tf.Tensor,
query_point_embedding: tf.Tensor,
key_point_embedding: tf.Tensor,
output_attentions: bool = False,
):
# Self attention block
if self.skip_first_layer_pe:
queries = self.self_attn(query=queries, key=queries, value=queries)
else:
query = queries + query_point_embedding
attn_out = self.self_attn(query=query, key=query, value=queries)
queries = queries + attn_out
queries = self.layer_norm1(queries)
# Cross attention block, tokens attending to image embedding
query = queries + query_point_embedding
key = keys + key_point_embedding
attn_out = self.cross_attn_token_to_image(query=query, key=key, value=keys)
queries = queries + attn_out
queries = self.layer_norm2(queries)
# MLP block
mlp_out = self.mlp(queries)
queries = queries + mlp_out
queries = self.layer_norm3(queries)
# Cross attention block, image embedding attending to tokens
query = queries + query_point_embedding
key = keys + key_point_embedding
attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries)
keys = keys + attn_out
keys = self.layer_norm4(keys)
outputs = (queries, keys)
if output_attentions:
outputs = outputs + (attn_out,)
else:
outputs = outputs + (None,)
return outputs
class TFSamTwoWayTransformer(tf.keras.layers.Layer):
def __init__(self, config: SamMaskDecoderConfig, **kwargs):
super().__init__(**kwargs)
self.config = config
self.num_hidden_layers = config.num_hidden_layers
self.layers = []
for i in range(self.num_hidden_layers):
self.layers.append(TFSamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0), name=f"layers_._{i}"))
self.final_attn_token_to_image = TFSamAttention(config, name="final_attn_token_to_image")
self.layer_norm_final_attn = tf.keras.layers.LayerNormalization(
epsilon=config.layer_norm_eps, name="layer_norm_final_attn"
)
def call(
self,
point_embeddings: tf.Tensor,
image_embeddings: tf.Tensor,
image_positional_embeddings: tf.Tensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, TFBaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
all_attentions = ()
if image_embeddings is None:
raise ValueError("You have to specify an image_embedding")
image_embeddings = tf.transpose(flatten(image_embeddings, 2), perm=(0, 2, 1))[:, None]
image_positional_embeddings = tf.transpose(flatten(image_positional_embeddings, 2), (0, 2, 1))[:, None]
# Prepare queries
queries = point_embeddings
keys = image_embeddings
# Apply transformer blocks and final layernorm
for layer in self.layers:
queries, keys, attention_outputs = layer(
queries=queries,
keys=keys,
query_point_embedding=point_embeddings,
key_point_embedding=image_positional_embeddings,
output_attentions=output_attentions,
)
if output_attentions:
all_attentions = all_attentions + (attention_outputs,)
# Apply the final attenion layer from the points to the image
query = queries + point_embeddings
key = keys + image_positional_embeddings
attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys)
queries = queries + attn_out
queries = self.layer_norm_final_attn(queries)
return queries, keys, all_attentions
class TFSamFeedForward(tf.keras.layers.Layer):
def __init__(
self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False, **kwargs
):
super().__init__(**kwargs)
self.num_layers = num_layers
self.activation = tf.keras.layers.ReLU()
self.proj_in = tf.keras.layers.Dense(hidden_dim, input_shape=(input_dim,), name="proj_in")
self.proj_out = tf.keras.layers.Dense(output_dim, input_shape=(hidden_dim,), name="proj_out")
self.layers = [
tf.keras.layers.Dense(hidden_dim, input_shape=(hidden_dim,), name=f"layers_._{i}")
for i in range(num_layers - 2)
]
self.sigmoid_output = sigmoid_output
def call(self, hidden_states):
hidden_states = self.proj_in(hidden_states)
hidden_states = self.activation(hidden_states)
for layer in self.layers:
hidden_states = self.activation(layer(hidden_states))
hidden_states = self.proj_out(hidden_states)
if self.sigmoid_output:
hidden_states = tf.sigmoid(hidden_states)
return hidden_states
class TFSamMaskDecoder(tf.keras.layers.Layer):
def __init__(self, config: SamMaskDecoderConfig, **kwargs):
super().__init__(**kwargs)
self.hidden_size = config.hidden_size
self.num_multimask_outputs = config.num_multimask_outputs
self.num_mask_tokens = config.num_multimask_outputs + 1
self.transformer = TFSamTwoWayTransformer(config, name="transformer")
self.upscale_conv1 = tf.keras.layers.Conv2DTranspose(
self.hidden_size // 4, kernel_size=2, strides=2, name="upscale_conv1", data_format="channels_first"
)
self.upscale_conv2 = tf.keras.layers.Conv2DTranspose(
self.hidden_size // 8, kernel_size=2, strides=2, name="upscale_conv2", data_format="channels_first"
)
self.upscale_layer_norm = TFSamLayerNorm(
self.hidden_size // 4, data_format="channels_first", name="upscale_layer_norm"
)
self.activation = tf.nn.gelu
mlps_list = []
for i in range(self.num_mask_tokens):
mlps_list += [
TFSamFeedForward(
self.hidden_size,
self.hidden_size,
self.hidden_size // 8,
3,
name=f"output_hypernetworks_mlps_._{i}",
)
]
self.output_hypernetworks_mlps = mlps_list
self.iou_prediction_head = TFSamFeedForward(
self.hidden_size,
config.iou_head_hidden_dim,
self.num_mask_tokens,
config.iou_head_depth,
name="iou_prediction_head",
)
def build(self, input_shape):
self.iou_token = self.add_weight(shape=(1, self.hidden_size), name="iou_token.weight", trainable=True)
self.mask_tokens = self.add_weight(
shape=(self.num_mask_tokens, self.hidden_size), name="mask_tokens.weight", trainable=True
)
super().build(input_shape)
def call(
self,
image_embeddings: tf.Tensor,
image_positional_embeddings: tf.Tensor,
sparse_prompt_embeddings: tf.Tensor,
dense_prompt_embeddings: tf.Tensor,
multimask_output: bool,
output_attentions: Optional[bool] = None,
) -> Tuple[tf.Tensor, tf.Tensor]:
batch_size, num_channels, height, width = shape_list(image_embeddings)
point_batch_size = tf.math.maximum(1, tf.shape(sparse_prompt_embeddings)[1])
output_tokens = tf.concat([self.iou_token, self.mask_tokens], axis=0) # Should be (1, 32) + (4, 32) = (5, 32)
output_tokens = tf.tile(
output_tokens[None, None, :], [batch_size, point_batch_size, 1, 1]
) # Should be (batch_size, point_size, 5, 32)
# Matt: The original Torch code checked that the sum of sparse_prompt_embeddings equalled 0. However, this only
# happens when the sparse prompt embeddings are an empty tensor with shape[1] == 0. I replaced
# it with an explicit shape check to avoid data-dependent control flow which breaks XLA.
if sparse_prompt_embeddings.shape[1] != 0:
tokens = tf.concat((output_tokens, sparse_prompt_embeddings), axis=2)
else:
tokens = output_tokens
point_embeddings = tf.cast(tokens, self.iou_token.dtype)
image_embeddings = image_embeddings + dense_prompt_embeddings
image_embeddings = tf.tile(image_embeddings, [point_batch_size, 1, 1, 1])
image_positional_embeddings = tf.tile(image_positional_embeddings, [point_batch_size, 1, 1, 1])
point_embedding, image_embeddings, attentions = self.transformer(
point_embeddings=point_embeddings,
image_embeddings=image_embeddings,
image_positional_embeddings=image_positional_embeddings,
output_attentions=output_attentions,
)
iou_token_out = point_embedding[:, :, 0, :]
mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]
image_embeddings = tf.transpose(image_embeddings, perm=(0, 1, 3, 2))
image_embeddings = tf.reshape(image_embeddings, [batch_size * point_batch_size, num_channels, height, width])
upscaled_embedding = self.upscale_conv1(image_embeddings)
upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))
hyper_in_list = []
for i in range(self.num_mask_tokens):
current_mlp = self.output_hypernetworks_mlps[i]
hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
hyper_in = tf.stack(hyper_in_list, axis=2)
_, num_channels, height, width = shape_list(upscaled_embedding)
upscaled_embedding = tf.reshape(
upscaled_embedding, [batch_size, point_batch_size, num_channels, height * width]
)
masks = tf.reshape(hyper_in @ upscaled_embedding, [batch_size, point_batch_size, -1, height, width])
iou_pred = self.iou_prediction_head(iou_token_out)
if multimask_output:
mask_slice = slice(1, None)
else:
mask_slice = slice(0, 1)
masks = masks[:, :, mask_slice, :, :]
iou_pred = iou_pred[:, :, mask_slice]
outputs = (masks, iou_pred)
if output_attentions:
outputs = outputs + (attentions,)
else:
outputs = outputs + (None,)
return outputs
class TFSamPositionalEmbedding(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.scale = config.hidden_size // 2
self.config = config
def build(self, input_shape):
# TODO Matt: What is going on here? Why is a non-trainable weight randomly initialized?
self.positional_embedding = self.add_weight(
name="positional_embedding",
shape=(2, self.config.num_pos_feats),
initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.scale),
trainable=False,
)
def call(self, input_coords, input_shape=None):
"""Positionally encode points that are normalized to [0,1]."""
coordinates = tf.identity(input_coords)
if input_shape is not None:
coordinates = tf.stack(
[
tf.cast(coordinates[:, :, :, 0], tf.float32) / input_shape[1],
tf.cast(coordinates[:, :, :, 1], tf.float32) / input_shape[0],
],
axis=-1,
)
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
coordinates = 2 * coordinates - 1
coordinates = tf.cast(coordinates, self.positional_embedding.dtype)
coordinates = tf.matmul(coordinates, self.positional_embedding)
coordinates = 2 * np.pi * coordinates
# outputs d_1 x ... x d_n x channel shape
return tf.concat([tf.sin(coordinates), tf.cos(coordinates)], axis=-1)
class TFSamMaskEmbedding(tf.keras.layers.Layer):
def __init__(self, config: SamPromptEncoderConfig, **kwargs):
super().__init__(**kwargs)
self.mask_input_channels = config.mask_input_channels // 4
self.activation = ACT2FN[config.hidden_act]
self.conv1 = tf.keras.layers.Conv2D(self.mask_input_channels, kernel_size=2, strides=2, name="conv1")
self.conv2 = tf.keras.layers.Conv2D(config.mask_input_channels, kernel_size=2, strides=2, name="conv2")
self.conv3 = tf.keras.layers.Conv2D(config.hidden_size, kernel_size=1, name="conv3")
self.layer_norm1 = TFSamLayerNorm(self.mask_input_channels, config.layer_norm_eps, name="layer_norm1")
self.layer_norm2 = TFSamLayerNorm(self.mask_input_channels * 4, config.layer_norm_eps, name="layer_norm2")
def call(self, masks):
masks = tf.transpose(masks, perm=(0, 2, 3, 1)) # Convert to channels-last
hidden_states = self.conv1(masks)
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.conv2(hidden_states)
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.activation(hidden_states)
dense_embeddings = self.conv3(hidden_states)
dense_embeddings = tf.transpose(dense_embeddings, perm=(0, 3, 1, 2)) # Convert back to channels-first
return dense_embeddings
def build(self, input_shape):
# This class needs an explicit build method because it isn't called with the standard dummy inputs
conv1_shape = [None, None, None, 1]
conv2_shape = [None, None, None, self.mask_input_channels]
conv3_shape = [None, None, None, self.mask_input_channels * 4]
layer_norm1_shape = [None, None, None, self.mask_input_channels]
layer_norm2_shape = [None, None, None, self.mask_input_channels * 4]
with tf.name_scope("conv1"):
self.conv1.build(conv1_shape)
with tf.name_scope("conv2"):
self.conv2.build(conv2_shape)
with tf.name_scope("conv3"):
self.conv3.build(conv3_shape)
with tf.name_scope("layer_norm1"):
self.layer_norm1.build(layer_norm1_shape)
with tf.name_scope("layer_norm2"):
self.layer_norm2.build(layer_norm2_shape)
super().build(input_shape)
class TFSamPromptEncoder(tf.keras.layers.Layer):
def __init__(self, config: SamPromptEncoderConfig, shared_patch_embedding, **kwargs):
super().__init__(**kwargs)
self.shared_embedding = shared_patch_embedding
self.mask_embed = TFSamMaskEmbedding(config, name="mask_embed")
self.no_mask_embed = None
self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size)
self.input_image_size = config.image_size
self.point_embed = []
self.hidden_size = config.hidden_size
self.not_a_point_embed = None
self.config = config
def build(self, input_shape):
self.no_mask_embed = self.add_weight(
name="no_mask_embed.weight",
shape=(1, self.hidden_size),
initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
trainable=True,
)
self.point_embed = [
self.add_weight(
name=f"point_embed_._{i}.weight",
shape=(1, self.hidden_size),
initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
trainable=True,
)
for i in range(self.config.num_point_embeddings)
]
self.not_a_point_embed = self.add_weight(
name="not_a_point_embed.weight",
shape=(1, self.hidden_size),
initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
trainable=True,
)
with tf.name_scope("mask_embed"):
# We must explicitly build the mask embed because it isn't touched by the standard dummy inputs
self.mask_embed.build(
(None, self.config.mask_input_channels, self.config.image_size, self.config.image_size)
)
super().build(input_shape)
def _embed_points(self, points: tf.Tensor, labels: tf.Tensor, pad: bool) -> tf.Tensor:
"""Embeds point prompts."""
points = points + 0.5 # Shift to center of pixel
if pad:
target_point_shape = (points.shape[0], points.shape[1], 1, points.shape[-1])
target_labels_shape = (points.shape[0], points.shape[1], 1)
padding_point = tf.zeros(target_point_shape, dtype=points.dtype)
padding_label = -tf.ones(target_labels_shape, dtype=labels.dtype)
points = tf.concat([points, padding_point], axis=2)
labels = tf.concat([labels, padding_label], axis=2)
input_shape = (self.input_image_size, self.input_image_size)
point_embedding = self.shared_embedding(points, input_shape)
point_embedding = tf.where(labels[..., None] == -1, self.not_a_point_embed[0], point_embedding)
point_embedding = tf.where(
labels[..., None] != -10,
point_embedding,
tf.zeros_like(point_embedding),
)
point_embedding = tf.where(
(labels == 0)[:, :, :, None], point_embedding + self.point_embed[0], point_embedding
)
point_embedding = tf.where(
(labels == 1)[:, :, :, None], point_embedding + self.point_embed[1], point_embedding
)
return point_embedding
def _embed_boxes(self, boxes: tf.Tensor) -> tf.Tensor:
"""Embeds box prompts."""
boxes = boxes + 0.5 # Shift to center of pixel
batch_size, nb_boxes = boxes.shape[:2]
coords = tf.reshape(boxes, (batch_size, nb_boxes, 2, 2))
input_shape = (self.input_image_size, self.input_image_size)
corner_embedding = self.shared_embedding(coords, input_shape)
corner_embedding += tf.where(
tf.range(corner_embedding.shape[2])[None, None, :, None] == 0,
self.point_embed[2][0],
self.point_embed[3][0],
)
return corner_embedding
def call(
self,
batch_size: Optional[int],
input_points: Optional[Tuple[tf.Tensor, tf.Tensor]],
input_labels: Optional[tf.Tensor],
input_boxes: Optional[tf.Tensor],
input_masks: Optional[tf.Tensor],
) -> Tuple[tf.Tensor, tf.Tensor]:
"""
Embeds different types of prompts, returning both sparse and dense embeddings.
Args:
points (`tf.Tensor`, *optional*):
point coordinates and labels to embed.
boxes (`tf.Tensor`, *optional*):
boxes to embed
masks (`tf.Tensor`, *optional*):
masks to embed
"""
sparse_embeddings = None
if input_points is not None:
batch_size, point_batch_size = input_points.shape[:2]
if input_labels is None:
raise ValueError("If points are provided, labels must also be provided.")
point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
sparse_embeddings = tf.zeros(
(batch_size, point_batch_size, 0, self.hidden_size), dtype=point_embeddings.dtype
)
sparse_embeddings = tf.concat([sparse_embeddings, point_embeddings], axis=2)
if input_boxes is not None:
batch_size = input_boxes.shape[0]
box_embeddings = self._embed_boxes(input_boxes)
if sparse_embeddings is None:
sparse_embeddings = box_embeddings
else:
sparse_embeddings = tf.concat([sparse_embeddings, box_embeddings], axis=2)
if input_masks is not None:
dense_embeddings = self.mask_embed(input_masks)
else:
dense_embeddings = self.no_mask_embed[0]
dense_embeddings = tf.reshape(dense_embeddings, (1, -1, 1, 1))
dense_embeddings = tf.tile(
dense_embeddings, (batch_size, 1, self.image_embedding_size[0], self.image_embedding_size[1])
)
if sparse_embeddings is None:
sparse_embeddings = tf.zeros((batch_size, 0, 1, self.hidden_size), dtype=dense_embeddings.dtype)
return sparse_embeddings, dense_embeddings
class TFSamVisionAttention(tf.keras.layers.Layer):
"""Multi-head Attention block with relative position embeddings."""
def __init__(self, config, window_size, **kwargs):
super().__init__(**kwargs)
input_size = (
(config.image_size // config.patch_size, config.image_size // config.patch_size)
if window_size == 0
else (window_size, window_size)
)
self.input_size = input_size
self.num_attention_heads = config.num_attention_heads
head_dim = config.hidden_size // config.num_attention_heads
self.head_dim = head_dim
self.scale = head_dim**-0.5
self.dropout = config.attention_dropout
self.qkv = tf.keras.layers.Dense(config.hidden_size * 3, use_bias=config.qkv_bias, name="qkv")
self.proj = tf.keras.layers.Dense(config.hidden_size, name="proj")
self.use_rel_pos = config.use_rel_pos
if self.use_rel_pos:
if input_size is None:
raise ValueError("Input size must be provided if using relative positional encoding.")
def build(self, input_shape):
if self.input_size is not None:
# initialize relative positional embeddings
self.rel_pos_h = self.add_weight(
shape=(2 * self.input_size[0] - 1, self.head_dim), initializer="zeros", name="rel_pos_h"
)
self.rel_pos_w = self.add_weight(
shape=(2 * self.input_size[1] - 1, self.head_dim), initializer="zeros", name="rel_pos_w"
)
super().build(input_shape)
def get_rel_pos(self, q_size: int, k_size: int, rel_pos: tf.Tensor) -> tf.Tensor:
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
q_size (int):
size of the query.
k_size (int):
size of key k.
rel_pos (`tf.Tensor`):
relative position embeddings (L, channel).
Returns:
Extracted positional embeddings according to relative positions.
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos if needed.
if rel_pos.shape[0] != max_rel_dist:
# Interpolate rel pos.
rel_pos_resized = tf.image.resize(
tf.reshape(rel_pos, (1, rel_pos.shape[0], -1)),
size=(max_rel_dist, rel_pos.shape[1]),
method="bilinear",
)
rel_pos_resized = tf.reshape(rel_pos_resized, (-1, max_rel_dist))
else:
rel_pos_resized = rel_pos
# Scale the coords with short length if shapes for q and k are different.
q_coords = tf.expand_dims(tf.range(q_size, dtype=tf.float32), 1) * max(k_size / q_size, 1.0)
k_coords = tf.expand_dims(tf.range(k_size, dtype=tf.float32), 0) * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
return tf.gather(rel_pos_resized, tf.cast(relative_coords, tf.int32))
def add_decomposed_rel_pos(
self,
attn: tf.Tensor,
query: tf.Tensor,
rel_pos_h: tf.Tensor,
rel_pos_w: tf.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> tf.Tensor:
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
Args:
attn (`tf.Tensor`):
attention map.
query (`tf.Tensor`):
query q in the attention layer with shape (batch_size, query_height * query_width, channel).
rel_pos_h (`tf.Tensor`):
relative position embeddings (Lh, channel) for height axis.
rel_pos_w (`tf.Tensor`):
relative position embeddings (Lw, channel) for width axis.
q_size (tuple):
spatial sequence size of query q with (query_height, query_width).
k_size (tuple):
spatial sequence size of key k with (key_height, key_width).
Returns:
attn (`tf.Tensor`):
attention map with added relative positional embeddings.
"""
query_height, query_width = q_size
key_height, key_width = k_size
relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
batch_size, _, dim = shape_list(query)
reshaped_query = tf.reshape(query, (batch_size, query_height, query_width, dim))
rel_h = tf.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
rel_w = tf.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
attn = tf.reshape(attn, (batch_size, query_height, query_width, key_height, key_width))
attn = attn + tf.expand_dims(rel_h, axis=-1) + tf.expand_dims(rel_w, axis=-2)
attn = tf.reshape(attn, (batch_size, query_height * query_width, key_height * key_width))
return attn
def call(self, hidden_states: tf.Tensor, output_attentions=False, training=False) -> tf.Tensor:
batch_size, height, width, _ = shape_list(hidden_states)
# qkv with shape (3, batch_size, nHead, height * width, channel)
qkv = tf.reshape(self.qkv(hidden_states), (batch_size, height * width, 3, self.num_attention_heads, -1))
qkv = tf.transpose(qkv, perm=(2, 0, 3, 1, 4))
# q, k, v with shape (batch_size * nHead, height * width, channel)
query, key, value = tf.unstack(
tf.reshape(qkv, (3, batch_size * self.num_attention_heads, height * width, -1)), axis=0
)
attn_weights = tf.matmul(query * self.scale, key, transpose_b=True)
if self.use_rel_pos:
attn_weights = self.add_decomposed_rel_pos(
attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
)
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
if training:
attn_probs = tf.nn.dropout(attn_weights, rate=self.dropout)
else:
attn_probs = attn_weights
attn_output = tf.reshape(attn_probs @ value, (batch_size, self.num_attention_heads, height, width, -1))
attn_output = tf.transpose(attn_output, perm=(0, 2, 3, 1, 4))
attn_output = tf.reshape(attn_output, (batch_size, height, width, -1))
attn_output = self.proj(attn_output)
if output_attentions:
outputs = (attn_output, attn_weights)
else:
outputs = (attn_output, None)
return outputs
class TFSamVisionLayer(tf.keras.layers.Layer):
def __init__(self, config, window_size, **kwargs):
super().__init__(**kwargs)
self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1")
self.attn = TFSamVisionAttention(config, window_size, name="attn")
self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2")
self.mlp = TFSamMLPBlock(config, name="mlp")
self.window_size = window_size
def window_partition(self, hidden_states: tf.Tensor, window_size: int) -> Tuple[tf.Tensor, Tuple[int, int]]:
batch_size, height, width, channel = shape_list(hidden_states)
pad_h = (window_size - height % window_size) % window_size
pad_w = (window_size - width % window_size) % window_size
if pad_h > 0 or pad_w > 0:
hidden_states = tf.pad(hidden_states, [[0, 0], [0, pad_h], [0, pad_w], [0, 0]])
pad_height, pad_width = height + pad_h, width + pad_w
hidden_states = tf.reshape(
hidden_states,
[batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel],
)
windows = tf.reshape(
tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), [-1, window_size, window_size, channel]
)
return windows, (pad_height, pad_width)
def window_unpartition(
self, windows: tf.Tensor, window_size: int, padding_shape: Tuple[int, int], original_shape: Tuple[int, int]
) -> tf.Tensor:
pad_height, pad_width = padding_shape
height, width = original_shape
batch_size = shape_list(windows)[0] // (pad_height * pad_width // window_size // window_size)
hidden_states = tf.reshape(
windows, [batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1]
)
hidden_states = tf.reshape(
tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), [batch_size, pad_height, pad_width, -1]
)
if pad_height > height or pad_width > width:
hidden_states = hidden_states[:, :height, :width, :]
return hidden_states
def call(
self,
hidden_states: tf.Tensor,
output_attentions: Optional[bool] = False,
training: Optional[bool] = False,
) -> Tuple[tf.Tensor]:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
if self.window_size > 0:
height, width = hidden_states.shape[1], hidden_states.shape[2]
hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size)
hidden_states, attn_weights = self.attn(
hidden_states=hidden_states,
output_attentions=output_attentions,
training=training,
)
if self.window_size > 0:
hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width))
hidden_states = residual + hidden_states
layernorm_output = self.layer_norm2(hidden_states)
hidden_states = hidden_states + self.mlp(layernorm_output)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class TFSamVisionNeck(tf.keras.layers.Layer):
def __init__(self, config: SamVisionConfig, **kwargs):
super().__init__(**kwargs)
self.config = config
self.conv1 = tf.keras.layers.Conv2D(
config.output_channels,
kernel_size=1,
use_bias=False,
name="conv1",
)
self.layer_norm1 = TFSamLayerNorm(config.output_channels, name="layer_norm1")
self.conv2 = tf.keras.layers.Conv2D(
config.output_channels,
kernel_size=3,
padding="same",
use_bias=False,
name="conv2",
)
self.layer_norm2 = TFSamLayerNorm(config.output_channels, name="layer_norm2")
def call(self, hidden_states):
hidden_states = self.conv1(hidden_states)
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.conv2(hidden_states)
hidden_states = self.layer_norm2(hidden_states)
hidden_states = tf.transpose(hidden_states, perm=[0, 3, 1, 2])
return hidden_states
class TFSamVisionEncoder(tf.keras.layers.Layer):
def __init__(self, config: SamVisionConfig, **kwargs):
super().__init__(**kwargs)
self.config = config
self.image_size = config.image_size
self.patch_embed = TFSamPatchEmbeddings(config, name="patch_embed")
self.pos_embed = None
self.layers = []
for i in range(config.num_hidden_layers):
layer = TFSamVisionLayer(
config,
window_size=config.window_size if i not in config.global_attn_indexes else 0,
name=f"layers_._{i}",
)
self.layers.append(layer)
self.neck = TFSamVisionNeck(config, name="neck")
def build(self, input_shape):
if self.config.use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
self.pos_embed = self.add_weight(
shape=[
1,
self.config.image_size // self.config.patch_size,
self.config.image_size // self.config.patch_size,
self.config.hidden_size,
],
initializer="zeros",
trainable=True,
name="pos_embed",
)
super().build(input_shape)
def get_input_embeddings(self):
return self.patch_embed
def call(
self,
pixel_values: Optional[tf.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = False,
) -> Union[Tuple, TFSamVisionEncoderOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
hidden_states = self.patch_embed(pixel_values)
if self.pos_embed is not None:
hidden_states = hidden_states + self.pos_embed
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layers):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(hidden_states, output_attentions=output_attentions, training=training)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
hidden_states = self.neck(hidden_states)
if not return_dict:
outputs = (hidden_states,)
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_self_attentions,)
return outputs
return TFSamVisionEncoderOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
class TFSamPreTrainedModel(TFPreTrainedModel):
config_class = SamConfig
base_model_prefix = "sam"
main_input_name = "pixel_values"
@property
def dummy_inputs(self) -> Dict[str, tf.Tensor]:
"""
Dummy inputs to build the network.
Returns:
`Dict[str, tf.Tensor]`: The dummy inputs.
"""
VISION_DUMMY_INPUTS = tf.random.uniform(
shape=(
3,
self.config.vision_config.num_channels,
self.config.vision_config.image_size,
self.config.vision_config.image_size,
),
dtype=tf.float32,
)
return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)}
@tf.function(
input_signature=[
{
"pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
}
]
)
def serving(self, inputs):
"""
Method used for serving the model.
Args:
inputs (`Dict[str, tf.Tensor]`):
The input of the saved model as a dictionary of tensors.
"""
output = self.call(inputs)
return self.serving_output(output)
SAM_START_DOCSTRING = r"""
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a TensorFlow [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model)
subclass. Use it as a regular TensorFlow Model and refer to the TensorFlow documentation for all matter related to
general usage and behavior.
Parameters:
config ([`SamConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""
SAM_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
details.
input_points (`tf.Tensor` of shape `(batch_size, num_points, 2)`):
Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much
better results. The points can be obtained by passing a list of list of list to the processor that will
create corresponding `tf` tensors of dimension 4. The first dimension is the image batch size, the second
dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict per
input point), the third dimension is the number of points per segmentation mask (it is possible to pass
multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal)
coordinates of the point. If a different number of points is passed either for each image, or for each
mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
computation of the embedding will be skipped for these points using the labels.
input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points)`):
Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
official implementation, there are 3 types of labels
- `1`: the point is a point that contains the object of interest
- `0`: the point is a point that does not contain the object of interest
- `-1`: the point corresponds to the background
We added the label:
- `-10`: the point is a padding point, thus should be ignored by the prompt encoder
The padding labels should be automatically done by the processor.
input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes, 4)`):
Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to
much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
that will generate a `tf` tensor, with each dimension corresponding respectively to the image batch size,
the number of boxes per image and the coordinates of the top left and botton right point of the box. In the
order (`x1`, `y1`, `x2`, `y2`):
- `x1`: the x coordinate of the top left point of the input box
- `y1`: the y coordinate of the top left point of the input box
- `x2`: the x coordinate of the bottom right point of the input box
- `y2`: the y coordinate of the bottom right point of the input box
input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`):
SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be
manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
image_embeddings (`tf.Tensor` of shape `(batch_size, output_channels, window_size, window_size)`):
Image embeddings, this is used by the mask decder to generate masks and iou scores. For more memory
efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
method, and then feed them to the `call` method instead of feeding the `pixel_values`.
multimask_output (`bool`, *optional*):
In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
"best" mask, by specifying `multimask_output=False`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"Segment Anything Model (SAM) for generating segmentation masks, given an input image and ",
" optional 2D location and bounding boxes.",
SAM_START_DOCSTRING,
)
class TFSamModel(TFSamPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"]
def __init__(self, config, **kwargs):
super().__init__(config, **kwargs)
self.shared_image_embedding = TFSamPositionalEmbedding(config.vision_config, name="shared_image_embedding")
self.vision_encoder = TFSamVisionEncoder(config.vision_config, name="vision_encoder")
self.prompt_encoder = TFSamPromptEncoder(
config.prompt_encoder_config, self.shared_image_embedding, name="prompt_encoder"
)
self.mask_decoder = TFSamMaskDecoder(config.mask_decoder_config, name="mask_decoder")
self.config = config
def get_input_embeddings(self):
return self.vision_encoder.get_input_embeddings()
def get_image_wide_positional_embeddings(self):
size = self.config.prompt_encoder_config.image_embedding_size
grid = tf.ones((size, size))
y_embed = tf.math.cumsum(grid, axis=0) - 0.5
x_embed = tf.math.cumsum(grid, axis=1) - 0.5
y_embed = y_embed / size
x_embed = x_embed / size
positional_embedding = self.shared_image_embedding(tf.stack([x_embed, y_embed], axis=-1))
return tf.expand_dims(tf.transpose(positional_embedding, perm=[2, 0, 1]), axis=0) # channel x height x width
def get_image_embeddings(
self,
pixel_values,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
r"""
Returns the image embeddings by passing the pixel values through the vision encoder.
Args:
pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
Input pixel values
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.TFModelOutput`] instead of a plain tuple.
"""
vision_output = self.vision_encoder(
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
image_embeddings = vision_output[0]
return image_embeddings
def get_prompt_embeddings(
self,
input_points: Optional[tf.Tensor] = None,
input_labels: Optional[tf.Tensor] = None,
input_boxes: Optional[tf.Tensor] = None,
input_masks: Optional[tf.Tensor] = None,
):
r"""
Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.
Args:
input_points (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
Optional input points for the prompt encoder. The padding of the point is automatically done by the
processor. `point_batch_size` refers to the number of masks that we want the model to predict per
point. The model will output `point_batch_size` times 3 masks in total.
input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
processor, or can be fed by the user.
input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes_per_image, 4)`):
Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
processor. users can also pass manually the input boxes.
input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`):
Optional input masks for the prompt encoder.
"""
prompt_output = self.prompt_encoder(
input_points=input_points,
input_labels=input_labels,
input_boxes=input_boxes,
input_masks=input_masks,
)
return prompt_output
@unpack_inputs
@add_start_docstrings_to_model_forward(SAM_INPUTS_DOCSTRING)
def call(
self,
pixel_values: Optional[tf.Tensor] = None,
input_points: Optional[tf.Tensor] = None,
input_labels: Optional[tf.Tensor] = None,
input_boxes: Optional[tf.Tensor] = None,
input_masks: Optional[tf.Tensor] = None,
image_embeddings: Optional[tf.Tensor] = None,
multimask_output: bool = True,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict=None,
training=False,
**kwargs,
) -> List[Dict[str, tf.Tensor]]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None and image_embeddings is None:
raise ValueError("Either pixel_values or image_embeddings must be provided.")
if pixel_values is not None and image_embeddings is not None:
raise ValueError("Only one of pixel_values and image_embeddings can be provided.")
if input_points is not None and len(input_points.shape) != 4:
raise ValueError(
"The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.",
" got {}.".format(input_points.shape),
)
if input_boxes is not None and len(input_boxes.shape) != 3:
raise ValueError(
"The input_points must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`.",
" got {}.".format(input_boxes.shape),
)
if input_points is not None and input_boxes is not None:
point_batch_size = input_points.shape[1]
box_batch_size = input_boxes.shape[1]
if point_batch_size != box_batch_size:
raise ValueError(
"You should provide as many bounding boxes as input points per box. Got {} and {}.".format(
point_batch_size, box_batch_size
)
)
if pixel_values is not None:
# Ensures that later checks pass even with an all-None shape from the serving signature
pixel_values = tf.ensure_shape(
pixel_values,
[
None,
self.config.vision_config.num_channels,
self.config.vision_config.image_size,
self.config.vision_config.image_size,
],
)
image_positional_embeddings = self.get_image_wide_positional_embeddings()
# repeat with batch size
batch_size = shape_list(pixel_values)[0] if pixel_values is not None else shape_list(image_embeddings)[0]
image_positional_embeddings = tf.repeat(image_positional_embeddings, batch_size, axis=0)
vision_attentions = None
vision_hidden_states = None
if pixel_values is not None:
vision_outputs = self.vision_encoder(
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
training=training,
)
image_embeddings = vision_outputs["last_hidden_state"]
if output_hidden_states:
vision_hidden_states = vision_outputs["hidden_states"]
if output_attentions:
vision_attentions = vision_outputs["attentions"]
if input_points is not None and input_labels is None:
input_labels = tf.ones_like(input_points[:, :, :, 0], dtype=tf.int32)
if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]:
raise ValueError(
"The batch size of the image embeddings and the input points must be the same. ",
"Got {} and {} respectively.".format(image_embeddings.shape[0], input_points.shape[0]),
" if you want to pass multiple points for the same image, make sure that you passed ",
" input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ",
" input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
)
sparse_embeddings, dense_embeddings = self.prompt_encoder(
batch_size=shape_list(image_embeddings)[0],
input_points=input_points,
input_labels=input_labels,
input_boxes=input_boxes,
input_masks=input_masks,
)
low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder(
image_embeddings=image_embeddings,
image_positional_embeddings=image_positional_embeddings,
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
output_attentions=output_attentions,
)
if not return_dict:
output = (iou_predictions, low_res_masks)
if output_hidden_states:
output = output + (vision_hidden_states,)
if output_attentions:
output = output + (vision_attentions, mask_decoder_attentions)
return output
return TFSamImageSegmentationOutput(
iou_scores=iou_predictions,
pred_masks=low_res_masks,
vision_hidden_states=vision_hidden_states,
vision_attentions=vision_attentions,
mask_decoder_attentions=mask_decoder_attentions,
)
def serving_output(self, output: TFSamImageSegmentationOutput) -> TFSamImageSegmentationOutput:
hs = tf.convert_to_tensor(output.vision_hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.vision_attentions) if self.config.output_attentions else None
return TFSamImageSegmentationOutput(
iou_scores=output.iou_scores,
pred_masks=output.pred_masks,
vision_hidden_states=hs if self.config.output_hidden_states else None,
vision_attentions=attns if self.config.output_attentions else None,
mask_decoder_attentions=output.mask_decoder_attentions if self.config.output_attentions else None,
)
......@@ -22,12 +22,15 @@ import numpy as np
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding
from ...utils import TensorType, is_torch_available
from ...utils import TensorType, is_tf_available, is_torch_available
if is_torch_available():
import torch
if is_tf_available():
import tensorflow as tf
class SamProcessor(ProcessorMixin):
r"""
......@@ -72,7 +75,7 @@ class SamProcessor(ProcessorMixin):
# pop arguments that are not used in the foward but used nevertheless
original_sizes = encoding_image_processor["original_sizes"]
if isinstance(original_sizes, torch.Tensor):
if hasattr(original_sizes, "numpy"): # Checks if Torch or TF tensor
original_sizes = original_sizes.numpy()
input_points, input_labels, input_boxes = self._check_and_preprocess_points(
......@@ -139,18 +142,30 @@ class SamProcessor(ProcessorMixin):
input_boxes = torch.from_numpy(input_boxes)
# boxes batch size of 1 by default
input_boxes = input_boxes.unsqueeze(1) if len(input_boxes.shape) != 3 else input_boxes
elif return_tensors == "tf":
input_boxes = tf.convert_to_tensor(input_boxes)
# boxes batch size of 1 by default
input_boxes = tf.expand_dims(input_boxes, 1) if len(input_boxes.shape) != 3 else input_boxes
encoding_image_processor.update({"input_boxes": input_boxes})
if input_points is not None:
if return_tensors == "pt":
input_points = torch.from_numpy(input_points)
# point batch size of 1 by default
input_points = input_points.unsqueeze(1) if len(input_points.shape) != 4 else input_points
elif return_tensors == "tf":
input_points = tf.convert_to_tensor(input_points)
# point batch size of 1 by default
input_points = tf.expand_dims(input_points, 1) if len(input_points.shape) != 4 else input_points
encoding_image_processor.update({"input_points": input_points})
if input_labels is not None:
if return_tensors == "pt":
input_labels = torch.from_numpy(input_labels)
# point batch size of 1 by default
input_labels = input_labels.unsqueeze(1) if len(input_labels.shape) != 3 else input_labels
elif return_tensors == "tf":
input_labels = tf.convert_to_tensor(input_labels)
# point batch size of 1 by default
input_labels = tf.expand_dims(input_labels, 1) if len(input_labels.shape) != 3 else input_labels
encoding_image_processor.update({"input_labels": input_labels})
return encoding_image_processor
......@@ -204,7 +219,7 @@ class SamProcessor(ProcessorMixin):
it is converted to a `numpy.ndarray` and then to a `list`.
"""
if input_points is not None:
if isinstance(input_points, torch.Tensor):
if hasattr(input_points, "numpy"): # Checks for TF or Torch tensor
input_points = input_points.numpy().tolist()
if not isinstance(input_points, list) or not isinstance(input_points[0], list):
......@@ -214,7 +229,7 @@ class SamProcessor(ProcessorMixin):
input_points = None
if input_labels is not None:
if isinstance(input_labels, torch.Tensor):
if hasattr(input_labels, "numpy"):
input_labels = input_labels.numpy().tolist()
if not isinstance(input_labels, list) or not isinstance(input_labels[0], list):
......@@ -224,7 +239,7 @@ class SamProcessor(ProcessorMixin):
input_labels = None
if input_boxes is not None:
if isinstance(input_boxes, torch.Tensor):
if hasattr(input_boxes, "numpy"):
input_boxes = input_boxes.numpy().tolist()
if (
......
......@@ -70,6 +70,56 @@ def stable_softmax(logits: tf.Tensor, axis: Optional[int] = None, name: Optional
return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)
def functional_layernorm(inputs, weight, bias, epsilon=1e-5, axis=-1):
# This is a very simplified functional layernorm, designed to duplicate
# the functionality of PyTorch nn.functional.layer_norm when this is needed to port
# models in Transformers.
if weight.shape.rank != 1 or bias.shape.rank != 1 or not isinstance(axis, int):
raise NotImplementedError("Only 1D weight and bias tensors are supported for now, with only a single axis.")
# Get mean and variance on the axis to be normalized
mean, variance = tf.nn.moments(inputs, axes=[axis], keepdims=True)
if axis != -1:
# Reshape scale and weight to have the same rank as inputs, but with 1 dimensions
# on every dimension except axis
shape = [1] * inputs.shape.rank
shape[axis] = shape_list(inputs)[axis]
weight = tf.reshape(weight, shape)
bias = tf.reshape(bias, shape)
# Compute layer normalization using the batch_normalization
# function.
outputs = tf.nn.batch_normalization(
inputs,
mean,
variance,
offset=bias,
scale=weight,
variance_epsilon=epsilon,
)
return outputs
def flatten(input, start_dim=0, end_dim=-1):
# Replicates the behavior of torch.flatten in TF
# If end_dim or start_dim is negative, count them from the end
if end_dim < 0:
end_dim += input.shape.rank
if start_dim < 0:
start_dim += input.shape.rank
if start_dim == end_dim:
return input
in_shape = tf.shape(input)
flattened_dim = tf.math.reduce_prod(in_shape[start_dim : end_dim + 1])
out_shape = tf.concat([in_shape[:start_dim], [flattened_dim], in_shape[end_dim + 1 :]], axis=0)
return tf.reshape(input, out_shape)
def invert_attention_mask(encoder_attention_mask: tf.Tensor) -> tf.Tensor:
"""
Invert an attention mask (e.g., switches 0. and 1.).
......
......@@ -2317,6 +2317,23 @@ class TFRoFormerPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["tf"])
TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TFSamModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFSamPreTrainedModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
......@@ -436,6 +436,9 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def test_hidden_states_output(self):
pass
def test_pt_tf_model_equivalence(self, allow_missing_keys=True, tol=5e-4):
super().test_pt_tf_model_equivalence(allow_missing_keys=True, tol=tol)
@slow
def test_model_from_pretrained(self):
for model_name in SAM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
......@@ -470,8 +473,10 @@ class SamModelIntegrationTest(unittest.TestCase):
with torch.no_grad():
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()
masks = outputs.pred_masks[0, 0, 0, 0, :3]
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=1e-4))
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=2e-4))
self.assertTrue(torch.allclose(masks, torch.tensor([-6.6381, -6.0734, -7.5308]).to(torch_device), atol=2e-4))
def test_inference_mask_generation_one_point_one_bb(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge")
......@@ -491,8 +496,12 @@ class SamModelIntegrationTest(unittest.TestCase):
with torch.no_grad():
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()
masks = outputs.pred_masks[0, 0, 0, 0, :3]
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9935), atol=1e-4))
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9935), atol=2e-4))
self.assertTrue(
torch.allclose(masks, torch.tensor([-21.5465, -23.1122, -22.3331]).to(torch_device), atol=2e-4)
)
def test_inference_mask_generation_batched_points_batched_images(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge")
......@@ -514,6 +523,7 @@ class SamModelIntegrationTest(unittest.TestCase):
with torch.no_grad():
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze().cpu()
masks = outputs.pred_masks[0, 0, 0, 0, :3].cpu()
EXPECTED_SCORES = torch.tensor(
[
......@@ -531,7 +541,9 @@ class SamModelIntegrationTest(unittest.TestCase):
],
]
)
EXPECTED_MASKS = torch.tensor([-26.5424, -34.0901, -30.6406])
self.assertTrue(torch.allclose(scores, EXPECTED_SCORES, atol=1e-3))
self.assertTrue(torch.allclose(masks, EXPECTED_MASKS, atol=1e-3))
def test_inference_mask_generation_one_point_one_bb_zero(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge")
......
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TensorFlow SAM model. """
import inspect
import unittest
import numpy as np
import requests
from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
from transformers.testing_utils import require_tf, slow
from transformers.utils import is_tf_available, is_vision_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available():
import tensorflow as tf
from transformers import SamProcessor, TFSamModel
from transformers.models.sam.modeling_tf_sam import TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
from PIL import Image
class TFSamPromptEncoderTester:
def __init__(
self,
hidden_size=32,
input_image_size=24,
patch_size=2,
mask_input_channels=4,
num_point_embeddings=4,
hidden_act="gelu",
):
self.hidden_size = hidden_size
self.input_image_size = input_image_size
self.patch_size = patch_size
self.mask_input_channels = mask_input_channels
self.num_point_embeddings = num_point_embeddings
self.hidden_act = hidden_act
def get_config(self):
return SamPromptEncoderConfig(
image_size=self.input_image_size,
patch_size=self.patch_size,
mask_input_channels=self.mask_input_channels,
hidden_size=self.hidden_size,
num_point_embeddings=self.num_point_embeddings,
hidden_act=self.hidden_act,
)
def prepare_config_and_inputs(self):
dummy_points = floats_tensor([self.batch_size, 3, 2])
config = self.get_config()
return config, dummy_points
class TFSamMaskDecoderTester:
def __init__(
self,
hidden_size=32,
hidden_act="relu",
mlp_dim=64,
num_hidden_layers=2,
num_attention_heads=4,
attention_downsample_rate=2,
num_multimask_outputs=3,
iou_head_depth=3,
iou_head_hidden_dim=32,
layer_norm_eps=1e-6,
):
self.hidden_size = hidden_size
self.hidden_act = hidden_act
self.mlp_dim = mlp_dim
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.attention_downsample_rate = attention_downsample_rate
self.num_multimask_outputs = num_multimask_outputs
self.iou_head_depth = iou_head_depth
self.iou_head_hidden_dim = iou_head_hidden_dim
self.layer_norm_eps = layer_norm_eps
def get_config(self):
return SamMaskDecoderConfig(
hidden_size=self.hidden_size,
hidden_act=self.hidden_act,
mlp_dim=self.mlp_dim,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
attention_downsample_rate=self.attention_downsample_rate,
num_multimask_outputs=self.num_multimask_outputs,
iou_head_depth=self.iou_head_depth,
iou_head_hidden_dim=self.iou_head_hidden_dim,
layer_norm_eps=self.layer_norm_eps,
)
def prepare_config_and_inputs(self):
config = self.get_config()
dummy_inputs = {
"image_embedding": floats_tensor([self.batch_size, self.hidden_size]),
}
return config, dummy_inputs
class TFSamModelTester:
def __init__(
self,
parent,
hidden_size=36,
intermediate_size=72,
projection_dim=62,
output_channels=32,
num_hidden_layers=2,
num_attention_heads=4,
num_channels=3,
image_size=24,
patch_size=2,
hidden_act="gelu",
layer_norm_eps=1e-06,
dropout=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
qkv_bias=True,
mlp_ratio=4.0,
use_abs_pos=True,
use_rel_pos=True,
rel_pos_zero_init=False,
window_size=14,
global_attn_indexes=[2, 5, 8, 11],
num_pos_feats=16,
mlp_dim=None,
batch_size=2,
):
self.parent = parent
self.image_size = image_size
self.patch_size = patch_size
self.output_channels = output_channels
self.num_channels = num_channels
self.hidden_size = hidden_size
self.projection_dim = projection_dim
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.dropout = dropout
self.attention_dropout = attention_dropout
self.initializer_range = initializer_range
self.initializer_factor = initializer_factor
self.hidden_act = hidden_act
self.layer_norm_eps = layer_norm_eps
self.qkv_bias = qkv_bias
self.mlp_ratio = mlp_ratio
self.use_abs_pos = use_abs_pos
self.use_rel_pos = use_rel_pos
self.rel_pos_zero_init = rel_pos_zero_init
self.window_size = window_size
self.global_attn_indexes = global_attn_indexes
self.num_pos_feats = num_pos_feats
self.mlp_dim = mlp_dim
self.batch_size = batch_size
# in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.seq_length = num_patches + 1
self.prompt_encoder_tester = TFSamPromptEncoderTester()
self.mask_decoder_tester = TFSamMaskDecoderTester()
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
config = self.get_config()
return config, pixel_values
def get_config(self):
vision_config = SamVisionConfig(
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
hidden_size=self.hidden_size,
projection_dim=self.projection_dim,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
dropout=self.dropout,
attention_dropout=self.attention_dropout,
initializer_range=self.initializer_range,
initializer_factor=self.initializer_factor,
output_channels=self.output_channels,
qkv_bias=self.qkv_bias,
mlp_ratio=self.mlp_ratio,
use_abs_pos=self.use_abs_pos,
use_rel_pos=self.use_rel_pos,
rel_pos_zero_init=self.rel_pos_zero_init,
window_size=self.window_size,
global_attn_indexes=self.global_attn_indexes,
num_pos_feats=self.num_pos_feats,
mlp_dim=self.mlp_dim,
)
prompt_encoder_config = self.prompt_encoder_tester.get_config()
mask_decoder_config = self.mask_decoder_tester.get_config()
return SamConfig(
vision_config=vision_config,
prompt_encoder_config=prompt_encoder_config,
mask_decoder_config=mask_decoder_config,
)
def create_and_check_model(self, config, pixel_values):
model = TFSamModel(config=config)
result = model(pixel_values)
self.parent.assertEqual(result.iou_scores.shape, (self.batch_size, 1, 3))
self.parent.assertEqual(result.pred_masks.shape[:3], (self.batch_size, 1, 3))
def create_and_check_get_image_features(self, config, pixel_values):
model = TFSamModel(config=config)
result = model.get_image_embeddings(pixel_values)
self.parent.assertEqual(result[0].shape, (self.output_channels, 12, 12))
def create_and_check_get_image_hidden_states(self, config, pixel_values):
model = TFSamModel(config=config)
result = model.vision_encoder(
pixel_values,
output_hidden_states=True,
return_dict=True,
)
# after computing the convolutional features
expected_hidden_states_shape = (self.batch_size, 12, 12, 36)
self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1)
self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape)
result = model.vision_encoder(
pixel_values,
output_hidden_states=True,
return_dict=False,
)
# after computing the convolutional features
expected_hidden_states_shape = (self.batch_size, 12, 12, 36)
self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1)
self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_tf
class TFSamModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as SAM's vision encoder does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (TFSamModel,) if is_tf_available() else ()
pipeline_model_mapping = (
{"feature-extraction": TFSamModel, "mask-generation": TFSamModel} if is_tf_available() else {}
)
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_onnx = False
# TODO: Fix me @Arthur: `run_batch_test` in `tests/test_pipeline_mixin.py` not working
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
):
return True
def setUp(self):
self.model_tester = TFSamModelTester(self)
self.vision_config_tester = ConfigTester(self, config_class=SamVisionConfig, has_text_modality=False)
self.prompt_encoder_config_tester = ConfigTester(
self,
config_class=SamPromptEncoderConfig,
has_text_modality=False,
num_attention_heads=12,
num_hidden_layers=2,
)
self.mask_decoder_config_tester = ConfigTester(
self, config_class=SamMaskDecoderConfig, has_text_modality=False
)
def test_config(self):
self.vision_config_tester.run_common_tests()
self.prompt_encoder_config_tester.run_common_tests()
self.mask_decoder_config_tester.run_common_tests()
@unittest.skip(reason="SAM's vision encoder does not use inputs_embeds")
def test_inputs_embeds(self):
pass
def test_model_common_attributes(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (tf.keras.layers.Layer))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, tf.keras.layers.Dense))
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.call)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_get_image_features(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_get_image_features(*config_and_inputs)
def test_image_hidden_states(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_get_image_hidden_states(*config_and_inputs)
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
expected_vision_attention_shape = (
self.model_tester.batch_size * self.model_tester.num_attention_heads,
196,
196,
)
expected_mask_decoder_attention_shape = (self.model_tester.batch_size, 1, 144, 32)
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
vision_attentions = outputs.vision_attentions
self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers)
mask_decoder_attentions = outputs.mask_decoder_attentions
self.assertEqual(len(mask_decoder_attentions), self.model_tester.mask_decoder_tester.num_hidden_layers)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
vision_attentions = outputs.vision_attentions
self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers)
mask_decoder_attentions = outputs.mask_decoder_attentions
self.assertEqual(len(mask_decoder_attentions), self.model_tester.mask_decoder_tester.num_hidden_layers)
self.assertListEqual(
list(vision_attentions[0].shape[-4:]),
list(expected_vision_attention_shape),
)
self.assertListEqual(
list(mask_decoder_attentions[0].shape[-4:]),
list(expected_mask_decoder_attention_shape),
)
@unittest.skip(reason="Hidden_states is tested in create_and_check_model tests")
def test_hidden_states_output(self):
pass
@slow
def test_model_from_pretrained(self):
for model_name in TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = TFSamModel.from_pretrained(model_name)
self.assertIsNotNone(model)
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=5e-4, name="outputs", attributes=None):
super().check_pt_tf_outputs(
tf_outputs=tf_outputs,
pt_outputs=pt_outputs,
model_class=model_class,
tol=tol,
name=name,
attributes=attributes,
)
def prepare_image():
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
return raw_image
def prepare_dog_img():
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dog-sam.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
return raw_image
@slow
class SamModelIntegrationTest(unittest.TestCase):
def test_inference_mask_generation_no_point(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
raw_image = prepare_image()
inputs = processor(images=raw_image, return_tensors="tf")
outputs = model(**inputs)
scores = tf.squeeze(outputs.iou_scores)
masks = outputs.pred_masks[0, 0, 0, 0, :3]
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.5798), atol=2e-4))
self.assertTrue(np.allclose(masks.numpy(), np.array([-6.6381, -6.0734, -7.5308]), atol=1e-2))
def test_inference_mask_generation_one_point_one_bb(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
raw_image = prepare_image()
input_boxes = [[[650, 900, 1000, 1250]]]
input_points = [[[820, 1080]]]
inputs = processor(images=raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="tf")
outputs = model(**inputs)
scores = tf.squeeze(outputs.iou_scores)
masks = outputs.pred_masks[0, 0, 0, 0, :3]
self.assertTrue(np.allclose(scores[-1], np.array(0.9935), atol=2e-4))
self.assertTrue(np.allclose(masks.numpy(), np.array([-21.5465, -23.1122, -22.3331]), atol=2e-2))
def test_inference_mask_generation_batched_points_batched_images(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
raw_image = prepare_image()
input_points = [
[[[820, 1080]], [[820, 1080]], [[820, 1080]], [[820, 1080]]],
[[[510, 1080]], [[820, 1080]], [[820, 1080]], [[820, 1080]]],
]
inputs = processor(images=[raw_image, raw_image], input_points=input_points, return_tensors="tf")
outputs = model(**inputs)
scores = tf.squeeze(outputs.iou_scores)
masks = outputs.pred_masks[0, 0, 0, 0, :3]
EXPECTED_SCORES = np.array(
[
[
[0.9673, 0.9441, 0.9084],
[0.9673, 0.9441, 0.9084],
[0.9673, 0.9441, 0.9084],
[0.9673, 0.9441, 0.9084],
],
[
[0.8405, 0.6292, 0.3840],
[0.9673, 0.9441, 0.9084],
[0.9673, 0.9441, 0.9084],
[0.9673, 0.9441, 0.9084],
],
]
)
EXPECTED_MASKS = np.array([-26.5424, -34.0901, -30.6406])
self.assertTrue(np.allclose(scores.numpy(), EXPECTED_SCORES, atol=1e-3))
self.assertTrue(np.allclose(masks.numpy(), EXPECTED_MASKS, atol=3e-2))
def test_inference_mask_generation_one_point_one_bb_zero(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
raw_image = prepare_image()
input_boxes = [[[620, 900, 1000, 1255]]]
input_points = [[[820, 1080]]]
labels = [[0]]
inputs = processor(
images=raw_image,
input_boxes=input_boxes,
input_points=input_points,
input_labels=labels,
return_tensors="tf",
)
outputs = model(**inputs)
scores = tf.squeeze(outputs.iou_scores)
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9689), atol=1e-4))
def test_inference_mask_generation_one_point(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
raw_image = prepare_image()
input_points = [[[400, 650]]]
input_labels = [[1]]
inputs = processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="tf")
outputs = model(**inputs)
scores = tf.squeeze(outputs.iou_scores)
self.assertTrue(np.allclose(scores[-1], np.array(0.9712), atol=1e-4))
# With no label
input_points = [[[400, 650]]]
inputs = processor(images=raw_image, input_points=input_points, return_tensors="tf")
outputs = model(**inputs)
scores = tf.squeeze(outputs.iou_scores)
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9712), atol=1e-4))
def test_inference_mask_generation_two_points(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
raw_image = prepare_image()
input_points = [[[400, 650], [800, 650]]]
input_labels = [[1, 1]]
inputs = processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="tf")
outputs = model(**inputs)
scores = tf.squeeze(outputs.iou_scores)
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9936), atol=1e-4))
# no labels
inputs = processor(images=raw_image, input_points=input_points, return_tensors="tf")
outputs = model(**inputs)
scores = tf.squeeze(outputs.iou_scores)
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9936), atol=1e-4))
def test_inference_mask_generation_two_points_batched(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
raw_image = prepare_image()
input_points = [[[400, 650], [800, 650]], [[400, 650]]]
input_labels = [[1, 1], [1]]
inputs = processor(
images=[raw_image, raw_image], input_points=input_points, input_labels=input_labels, return_tensors="tf"
)
outputs = model(**inputs)
scores = tf.squeeze(outputs.iou_scores)
self.assertTrue(np.allclose(scores[0][-1].numpy(), np.array(0.9936), atol=1e-4))
self.assertTrue(np.allclose(scores[1][-1], np.array(0.9716), atol=1e-4))
def test_inference_mask_generation_one_box(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
raw_image = prepare_image()
input_boxes = [[[75, 275, 1725, 850]]]
inputs = processor(images=raw_image, input_boxes=input_boxes, return_tensors="tf")
outputs = model(**inputs)
scores = tf.squeeze(outputs.iou_scores)
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.8686), atol=1e-4))
def test_inference_mask_generation_batched_image_one_point(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
raw_image = prepare_image()
raw_dog_image = prepare_dog_img()
input_points = [[[820, 1080]], [[220, 470]]]
inputs = processor(images=[raw_image, raw_dog_image], input_points=input_points, return_tensors="tf")
outputs = model(**inputs)
scores_batched = tf.squeeze(outputs.iou_scores)
input_points = [[[220, 470]]]
inputs = processor(images=raw_dog_image, input_points=input_points, return_tensors="tf")
outputs = model(**inputs)
scores_single = tf.squeeze(outputs.iou_scores)
self.assertTrue(np.allclose(scores_batched[1, :].numpy(), scores_single.numpy(), atol=1e-4))
def test_inference_mask_generation_two_points_point_batch(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
raw_image = prepare_image()
# fmt: off
input_points = tf.convert_to_tensor([[[400, 650]], [[220, 470]]])
# fmt: on
input_points = tf.expand_dims(input_points, 0)
inputs = processor(raw_image, input_points=input_points, return_tensors="tf")
outputs = model(**inputs)
iou_scores = outputs.iou_scores
self.assertTrue(iou_scores.shape == (1, 2, 3))
self.assertTrue(
np.allclose(
iou_scores.numpy(),
np.array([[[0.9848, 0.9788, 0.9713], [0.9211, 0.9128, 0.7427]]]),
atol=1e-4,
rtol=1e-4,
)
)
def test_inference_mask_generation_three_boxes_point_batch(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
raw_image = prepare_image()
# fmt: off
input_boxes = tf.convert_to_tensor([[[620, 900, 1000, 1255]], [[75, 275, 1725, 850]], [[75, 275, 1725, 850]]])
EXPECTED_IOU = np.array([[[1.0071, 1.0032, 0.9946], [0.4962, 0.8770, 0.8686], [0.4962, 0.8770, 0.8686]]])
# fmt: on
input_boxes = tf.expand_dims(input_boxes, 0)
inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="tf")
outputs = model(**inputs)
iou_scores = outputs.iou_scores
self.assertTrue(iou_scores.shape == (1, 3, 3))
self.assertTrue(np.allclose(iou_scores.numpy(), EXPECTED_IOU, atol=1e-4, rtol=1e-4))
......@@ -17,8 +17,14 @@ import unittest
import numpy as np
from transformers.testing_utils import require_torch, require_torchvision, require_vision
from transformers.utils import is_torch_available, is_vision_available
from transformers.testing_utils import (
is_pt_tf_cross_test,
require_tf,
require_torch,
require_torchvision,
require_vision,
)
from transformers.utils import is_tf_available, is_torch_available, is_vision_available
if is_vision_available():
......@@ -29,6 +35,9 @@ if is_vision_available():
if is_torch_available():
import torch
if is_tf_available():
import tensorflow as tf
@require_vision
@require_torchvision
......@@ -110,3 +119,158 @@ class SamProcessorTest(unittest.TestCase):
dummy_masks = [[1, 0], [0, 1]]
with self.assertRaises(ValueError):
masks = processor.post_process_masks(dummy_masks, np.array(original_sizes), np.array(reshaped_input_size))
@require_vision
@require_tf
class TFSamProcessorTest(unittest.TestCase):
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
image_processor = SamImageProcessor()
processor = SamProcessor(image_processor)
processor.save_pretrained(self.tmpdirname)
def get_image_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def prepare_image_inputs(self):
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
or a list of PyTorch tensors if one specifies torchify=True.
"""
image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
return image_inputs
def test_save_load_pretrained_additional_features(self):
processor = SamProcessor(image_processor=self.get_image_processor())
processor.save_pretrained(self.tmpdirname)
image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0)
processor = SamProcessor.from_pretrained(self.tmpdirname, do_normalize=False, padding_value=1.0)
self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string())
self.assertIsInstance(processor.image_processor, SamImageProcessor)
def test_image_processor(self):
image_processor = self.get_image_processor()
processor = SamProcessor(image_processor=image_processor)
image_input = self.prepare_image_inputs()
input_feat_extract = image_processor(image_input, return_tensors="np")
input_processor = processor(images=image_input, return_tensors="np")
input_feat_extract.pop("original_sizes") # pop original_sizes as it is popped in the processor
input_feat_extract.pop("reshaped_input_sizes") # pop reshaped_input_sizes as it is popped in the processor
for key in input_feat_extract.keys():
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
@require_tf
def test_post_process_masks(self):
image_processor = self.get_image_processor()
processor = SamProcessor(image_processor=image_processor)
dummy_masks = [tf.ones((1, 3, 5, 5))]
original_sizes = [[1764, 2646]]
reshaped_input_size = [[683, 1024]]
masks = processor.post_process_masks(dummy_masks, original_sizes, reshaped_input_size, return_tensors="tf")
self.assertEqual(masks[0].shape, (1, 3, 1764, 2646))
masks = processor.post_process_masks(
dummy_masks,
tf.convert_to_tensor(original_sizes),
tf.convert_to_tensor(reshaped_input_size),
return_tensors="tf",
)
self.assertEqual(masks[0].shape, (1, 3, 1764, 2646))
# should also work with np
dummy_masks = [np.ones((1, 3, 5, 5))]
masks = processor.post_process_masks(
dummy_masks, np.array(original_sizes), np.array(reshaped_input_size), return_tensors="tf"
)
self.assertEqual(masks[0].shape, (1, 3, 1764, 2646))
dummy_masks = [[1, 0], [0, 1]]
with self.assertRaises(tf.errors.InvalidArgumentError):
masks = processor.post_process_masks(
dummy_masks, np.array(original_sizes), np.array(reshaped_input_size), return_tensors="tf"
)
@require_vision
@require_torchvision
class SamProcessorEquivalenceTest(unittest.TestCase):
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
image_processor = SamImageProcessor()
processor = SamProcessor(image_processor)
processor.save_pretrained(self.tmpdirname)
def get_image_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def prepare_image_inputs(self):
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
or a list of PyTorch tensors if one specifies torchify=True.
"""
image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
return image_inputs
@is_pt_tf_cross_test
def test_post_process_masks_equivalence(self):
image_processor = self.get_image_processor()
processor = SamProcessor(image_processor=image_processor)
dummy_masks = np.random.randint(0, 2, size=(1, 3, 5, 5)).astype(np.float32)
tf_dummy_masks = [tf.convert_to_tensor(dummy_masks)]
pt_dummy_masks = [torch.tensor(dummy_masks)]
original_sizes = [[1764, 2646]]
reshaped_input_size = [[683, 1024]]
tf_masks = processor.post_process_masks(
tf_dummy_masks, original_sizes, reshaped_input_size, return_tensors="tf"
)
pt_masks = processor.post_process_masks(
pt_dummy_masks, original_sizes, reshaped_input_size, return_tensors="pt"
)
self.assertTrue(np.all(tf_masks[0].numpy() == pt_masks[0].numpy()))
@is_pt_tf_cross_test
def test_image_processor_equivalence(self):
image_processor = self.get_image_processor()
processor = SamProcessor(image_processor=image_processor)
image_input = self.prepare_image_inputs()
pt_input_feat_extract = image_processor(image_input, return_tensors="pt")["pixel_values"].numpy()
pt_input_processor = processor(images=image_input, return_tensors="pt")["pixel_values"].numpy()
tf_input_feat_extract = image_processor(image_input, return_tensors="tf")["pixel_values"].numpy()
tf_input_processor = processor(images=image_input, return_tensors="tf")["pixel_values"].numpy()
self.assertTrue(np.allclose(pt_input_feat_extract, pt_input_processor))
self.assertTrue(np.allclose(pt_input_feat_extract, tf_input_feat_extract))
self.assertTrue(np.allclose(pt_input_feat_extract, tf_input_processor))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment