Unverified Commit 0ae58204 authored by İdil Sülo, committed by GitHub

Add visual prompt to processor of CLIPSeg model (#20816)

Adds a `visual_prompt` argument to `CLIPSegProcessor` to enable image-guided segmentation.
parent 2da82bb4
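
With this change the processor accepts an exemplar image instead of a text prompt. A minimal usage sketch of the new call pattern (the `CIDAS/clipseg-rd64-refined` checkpoint, the COCO image URL, and the crop coordinates are illustrative assumptions, not part of this commit):

```python
import requests
from PIL import Image

from transformers import CLIPSegProcessor

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
visual_prompt = image.crop((0, 0, 200, 200))  # any exemplar image works as the prompt

# Text-guided (unchanged): input_ids, attention_mask, pixel_values
text_inputs = processor(text=["a cat"], images=image, return_tensors="pt")

# Image-guided (new): pixel_values plus conditional_pixel_values
visual_inputs = processor(images=image, visual_prompt=visual_prompt, return_tensors="pt")
print(sorted(visual_inputs.keys()))  # ['conditional_pixel_values', 'pixel_values']
```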
```diff
@@ -56,7 +56,7 @@ class CLIPSegProcessor(ProcessorMixin):
         super().__init__(image_processor, tokenizer)
 
-    def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
+    def __call__(self, text=None, images=None, visual_prompt=None, return_tensors=None, **kwargs):
         """
         Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
         and `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode
@@ -73,6 +73,10 @@ class CLIPSegProcessor(ProcessorMixin):
                 The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                 tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
                 number of channels, H and W are image height and width.
+            visual_prompt (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                The visual prompt image or batch of images to be prepared. Each visual prompt image can be a PIL image,
+                NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape
+                (C, H, W), where C is a number of channels, H and W are image height and width.
             return_tensors (`str` or [`~utils.TensorType`], *optional*):
                 If set, will return tensors of a particular framework. Acceptable values are:
@@ -91,21 +95,37 @@ class CLIPSegProcessor(ProcessorMixin):
             `None`).
 
             - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
         """
-        if text is None and images is None:
-            raise ValueError("You have to specify either text or images. Both cannot be none.")
+        if text is None and visual_prompt is None and images is None:
+            raise ValueError("You have to specify either text, visual prompt or images.")
+
+        if text is not None and visual_prompt is not None:
+            raise ValueError("You have to specify exactly one type of prompt. Either text or visual prompt.")
 
         if text is not None:
             encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)
 
+        if visual_prompt is not None:
+            prompt_features = self.image_processor(visual_prompt, return_tensors=return_tensors, **kwargs)
+
         if images is not None:
             image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs)
 
-        if text is not None and images is not None:
+        if visual_prompt is not None and images is not None:
+            encoding = {
+                "pixel_values": image_features.pixel_values,
+                "conditional_pixel_values": prompt_features.pixel_values,
+            }
+            return encoding
+        elif text is not None and images is not None:
             encoding["pixel_values"] = image_features.pixel_values
             return encoding
         elif text is not None:
             return encoding
+        elif visual_prompt is not None:
+            encoding = {
+                "conditional_pixel_values": prompt_features.pixel_values,
+            }
+            return encoding
         else:
             return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
```
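
The `conditional_pixel_values` key produced by the new branch lines up with the argument of the same name on `CLIPSegForImageSegmentation.forward`, so the returned dict can be splatted straight into the model. A sketch under that assumption, with synthetic images standing in for real data:

```python
import torch
from PIL import Image

from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

image = Image.new("RGB", (352, 352), "white")         # stand-in target image
visual_prompt = Image.new("RGB", (352, 352), "gray")  # stand-in prompt image

inputs = processor(images=image, visual_prompt=visual_prompt, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)  # conditional_pixel_values conditions the segmentation decoder

print(outputs.logits.shape)  # per-pixel segmentation logits
```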
```diff
@@ -157,7 +157,7 @@ class CLIPSegProcessorTest(unittest.TestCase):
         for key in encoded_tok.keys():
             self.assertListEqual(encoded_tok[key], encoded_processor[key])
 
-    def test_processor(self):
+    def test_processor_text(self):
         image_processor = self.get_image_processor()
         tokenizer = self.get_tokenizer()
 
@@ -174,6 +174,23 @@ class CLIPSegProcessorTest(unittest.TestCase):
         with pytest.raises(ValueError):
             processor()
 
+    def test_processor_visual_prompt(self):
+        image_processor = self.get_image_processor()
+        tokenizer = self.get_tokenizer()
+
+        processor = CLIPSegProcessor(tokenizer=tokenizer, image_processor=image_processor)
+
+        image_input = self.prepare_image_inputs()
+        visual_prompt_input = self.prepare_image_inputs()
+
+        inputs = processor(images=image_input, visual_prompt=visual_prompt_input)
+
+        self.assertListEqual(list(inputs.keys()), ["pixel_values", "conditional_pixel_values"])
+
+        # test if it raises when no input is passed
+        with pytest.raises(ValueError):
+            processor()
+
     def test_tokenizer_decode(self):
         image_processor = self.get_image_processor()
         tokenizer = self.get_tokenizer()
```
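
The new test covers the visual-prompt path and the no-input error. A natural companion check for the new exclusivity error would be a method like the following on `CLIPSegProcessorTest` (a hypothetical addition, not part of this diff):

```python
def test_processor_text_and_visual_prompt_raises(self):
    # hypothetical: text and visual_prompt are mutually exclusive, so passing both must raise
    processor = CLIPSegProcessor(
        tokenizer=self.get_tokenizer(), image_processor=self.get_image_processor()
    )

    with pytest.raises(ValueError):
        processor(
            text="a cat",
            images=self.prepare_image_inputs(),
            visual_prompt=self.prepare_image_inputs(),
        )
```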