"examples/vscode:/vscode.git/clone" did not exist on "378435584ead2780e17af4ec1ed3353543e85ccb"
Unverified commit 7579a52b authored by Arthur, committed by GitHub

Small sam patch (#22920)



* patch

* add test

* move tests

* cover more cases (will fail now, update the code)

* style

* fix

* Update src/transformers/models/sam/image_processing_sam.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/sam/image_processing_sam.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* add better check

---------
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
parent 5166c30e
src/transformers/models/sam/image_processing_sam.py

@@ -378,12 +378,13 @@ class SamImageProcessor(BaseImageProcessor):
         Remove padding and upscale masks to the original image size.

         Args:
-            masks (`torch.Tensor`):
+            masks (`Union[List[torch.Tensor], List[np.ndarray]]`):
                 Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
-            original_sizes (`torch.Tensor`):
-                The original size of the images before resizing for input to the model, in (height, width) format.
-            reshaped_input_sizes (`torch.Tensor`):
-                The size of the image input to the model, in (height, width) format. Used to remove padding.
+            original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
+                The original sizes of each image before it was resized to the model's expected input shape, in (height,
+                width) format.
+            reshaped_input_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
+                The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
             mask_threshold (`float`, *optional*, defaults to 0.0):
                 The threshold to use for binarizing the masks.
             binarize (`bool`, *optional*, defaults to `True`):
@@ -398,9 +399,16 @@ class SamImageProcessor(BaseImageProcessor):
         requires_backends(self, ["torch"])
         pad_size = self.pad_size if pad_size is None else pad_size
         target_image_size = (pad_size["height"], pad_size["width"])
+        if isinstance(original_sizes, (torch.Tensor, np.ndarray)):
+            original_sizes = original_sizes.tolist()
+        if isinstance(reshaped_input_sizes, (torch.Tensor, np.ndarray)):
+            reshaped_input_sizes = reshaped_input_sizes.tolist()
         output_masks = []
         for i, original_size in enumerate(original_sizes):
+            if isinstance(masks[i], np.ndarray):
+                masks[i] = torch.from_numpy(masks[i])
+            elif not isinstance(masks[i], torch.Tensor):
+                raise ValueError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`")
             interpolated_mask = F.interpolate(masks[i], target_image_size, mode="bilinear", align_corners=False)
             interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]]
             interpolated_mask = F.interpolate(interpolated_mask, original_size, mode="bilinear", align_corners=False)
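For reference, a minimal sketch of what this patch enables, assuming torch is installed and a transformers version that includes this change: post_process_masks can now be called with NumPy masks and plain lists or arrays for the sizes, instead of torch tensors only. The SamImageProcessor() defaults and the dummy shapes below are illustrative, not part of the commit.

# Illustrative sketch only: post_process_masks accepting NumPy inputs after this patch.
import numpy as np

from transformers import SamImageProcessor

image_processor = SamImageProcessor()  # default pad_size (1024x1024 for SAM) is assumed here

# Low-resolution masks as the SAM mask decoder would produce them, one entry per image,
# in (batch_size, num_masks, height, width) format.
dummy_masks = [np.ones((1, 3, 256, 256), dtype=np.float32)]

original_sizes = np.array([[1764, 2646]])        # (height, width) of the raw image
reshaped_input_sizes = np.array([[683, 1024]])   # (height, width) after resizing for the model

# NumPy masks are converted to torch tensors internally; size arrays are converted to lists.
upscaled = image_processor.post_process_masks(dummy_masks, original_sizes, reshaped_input_sizes)
print(upscaled[0].shape)  # torch.Size([1, 3, 1764, 2646])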
@@ -17,8 +17,8 @@ import unittest
 import numpy as np

-from transformers.testing_utils import require_torchvision, require_vision
-from transformers.utils import is_vision_available
+from transformers.testing_utils import require_torch, require_torchvision, require_vision
+from transformers.utils import is_torch_available, is_vision_available

 if is_vision_available():
@@ -26,6 +26,9 @@ if is_vision_available():
     from transformers import AutoProcessor, SamImageProcessor, SamProcessor

+if is_torch_available():
+    import torch
+

 @require_vision
 @require_torchvision
@@ -79,3 +82,31 @@ class SamProcessorTest(unittest.TestCase):
         for key in input_feat_extract.keys():
             self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
+
+    @require_torch
+    def test_post_process_masks(self):
+        image_processor = self.get_image_processor()
+
+        processor = SamProcessor(image_processor=image_processor)
+        dummy_masks = [torch.ones((1, 3, 5, 5))]
+
+        original_sizes = [[1764, 2646]]
+
+        reshaped_input_size = [[683, 1024]]
+        masks = processor.post_process_masks(dummy_masks, original_sizes, reshaped_input_size)
+        self.assertEqual(masks[0].shape, (1, 3, 1764, 2646))
+
+        masks = processor.post_process_masks(
+            dummy_masks, torch.tensor(original_sizes), torch.tensor(reshaped_input_size)
+        )
+        self.assertEqual(masks[0].shape, (1, 3, 1764, 2646))
+
+        # should also work with np
+        dummy_masks = [np.ones((1, 3, 5, 5))]
+        masks = processor.post_process_masks(dummy_masks, np.array(original_sizes), np.array(reshaped_input_size))
+
+        self.assertEqual(masks[0].shape, (1, 3, 1764, 2646))
+
+        dummy_masks = [[1, 0], [0, 1]]
+        with self.assertRaises(ValueError):
+            masks = processor.post_process_masks(dummy_masks, np.array(original_sizes), np.array(reshaped_input_size))