"docs/git@developer.sourcefind.cn:modelzoo/qwen_lmdeploy.git" did not exist on "96f1b8ef751872cfe542e2a762f9b6fab7a69659"
Unverified Commit f3d99e49 authored by amyeroberts, committed by GitHub

Update VisionEncoderDecoder to use an image processor (#20137)

* TrOCR processor uses an image processor

* Update VisionEncoderDecoder

* Add feature_extractor_class property
parent a44985b4
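
Note on the migration path (illustrative, not part of the diff): the processors keep accepting the old `feature_extractor` keyword until v4.27 but emit a `FutureWarning`. A minimal sketch, assuming the `microsoft/trocr-base-handwritten` checkpoint purely for illustration:

```python
from transformers import TrOCRProcessor

# Load an existing processor; after this change its first attribute is `image_processor`.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")

# New style: pass the image processor explicitly.
processor = TrOCRProcessor(image_processor=processor.image_processor, tokenizer=processor.tokenizer)

# Legacy style: still accepted during the deprecation window, but emits a FutureWarning.
processor = TrOCRProcessor(feature_extractor=processor.image_processor, tokenizer=processor.tokenizer)
```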
@@ -25,28 +25,43 @@ class TrOCRProcessor(ProcessorMixin):
     r"""
     Constructs a TrOCR processor which wraps a vision feature extractor and a TrOCR tokenizer into a single processor.

-    [`TrOCRProcessor`] offers all the functionalities of [`ViTFeatureExtractor`/`DeiTFeatureExtractor`] and
+    [`TrOCRProcessor`] offers all the functionalities of [`ViTImageProcessor`/`DeiTImageProcessor`] and
     [`RobertaTokenizer`/`XLMRobertaTokenizer`]. See the [`~TrOCRProcessor.__call__`] and [`~TrOCRProcessor.decode`] for
     more information.

     Args:
-        feature_extractor ([`ViTFeatureExtractor`/`DeiTFeatureExtractor`]):
-            An instance of [`ViTFeatureExtractor`/`DeiTFeatureExtractor`]. The feature extractor is a required input.
+        image_processor ([`ViTImageProcessor`/`DeiTImageProcessor`]):
+            An instance of [`ViTImageProcessor`/`DeiTImageProcessor`]. The image processor is a required input.
         tokenizer ([`RobertaTokenizer`/`XLMRobertaTokenizer`]):
             An instance of [`RobertaTokenizer`/`XLMRobertaTokenizer`]. The tokenizer is a required input.
     """

-    feature_extractor_class = "AutoFeatureExtractor"
+    attributes = ["image_processor", "tokenizer"]
+    image_processor_class = "AutoImageProcessor"
     tokenizer_class = "AutoTokenizer"

-    def __init__(self, feature_extractor, tokenizer):
-        super().__init__(feature_extractor, tokenizer)
-        self.current_processor = self.feature_extractor
+    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
+        if "feature_extractor" in kwargs:
+            warnings.warn(
+                "The `feature_extractor` argument is deprecated and will be removed in v4.27, use `image_processor`"
+                " instead.",
+                FutureWarning,
+            )
+            feature_extractor = kwargs.pop("feature_extractor")
+
+        image_processor = image_processor if image_processor is not None else feature_extractor
+        if image_processor is None:
+            raise ValueError("You need to specify an `image_processor`.")
+        if tokenizer is None:
+            raise ValueError("You need to specify a `tokenizer`.")
+
+        super().__init__(image_processor, tokenizer)
+        self.current_processor = self.image_processor
         self._in_target_context_manager = False

     def __call__(self, *args, **kwargs):
         """
-        When used in normal mode, this method forwards all its arguments to AutoFeatureExtractor's
-        [`~AutoFeatureExtractor.__call__`] and returns its output. If used in the context
+        When used in normal mode, this method forwards all its arguments to AutoImageProcessor's
+        [`~AutoImageProcessor.__call__`] and returns its output. If used in the context
         [`~TrOCRProcessor.as_target_processor`] this method forwards all its arguments to TrOCRTokenizer's
         [`~TrOCRTokenizer.__call__`]. Please refer to the doctsring of the above two methods for more information.
         """

@@ -64,7 +79,7 @@ class TrOCRProcessor(ProcessorMixin):
             raise ValueError("You need to specify either an `images` or `text` input to process.")

         if images is not None:
-            inputs = self.feature_extractor(images, *args, **kwargs)
+            inputs = self.image_processor(images, *args, **kwargs)
         if text is not None:
             encodings = self.tokenizer(text, **kwargs)

@@ -103,5 +118,14 @@ class TrOCRProcessor(ProcessorMixin):
         self._in_target_context_manager = True
         self.current_processor = self.tokenizer
         yield
-        self.current_processor = self.feature_extractor
+        self.current_processor = self.image_processor
         self._in_target_context_manager = False
+
+    @property
+    def feature_extractor_class(self):
+        warnings.warn(
+            "`feature_extractor_class` is deprecated and will be removed in v4.27. Use `image_processor_class`"
+            " instead.",
+            FutureWarning,
+        )
+        return self.image_processor_class
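
As a usage note (illustrative, not part of the diff): images now flow through `processor.image_processor` while target text still goes to the tokenizer, and the old `feature_extractor_class` attribute remains readable but warns. A sketch, assuming a local image file `line.png`:

```python
from PIL import Image
from transformers import TrOCRProcessor

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
image = Image.open("line.png").convert("RGB")  # assumed local handwriting crop

# Image inputs are handled by the image processor (formerly the feature extractor).
pixel_values = processor(images=image, return_tensors="pt").pixel_values

# Target text for training is still encoded by the tokenizer.
labels = processor.tokenizer("a handwritten sentence", return_tensors="pt").input_ids

# Reading the deprecated class attribute emits a FutureWarning and returns "AutoImageProcessor".
print(processor.feature_extractor_class)
```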
@@ -87,8 +87,8 @@ VISION_ENCODER_DECODER_START_DOCSTRING = r"""
VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]` ``Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
-           Pixel values. Pixel values can be obtained using the vision's model's feature extractor. For example, using
-           [`ViTFeatureExtractor`]. See [`ViTFeatureExtractor.__call__`] for details.
+           Pixel values. Pixel values can be obtained using the vision's model's image processor. For example, using
+           [`ViTImageProcessor`]. See [`ViTImageProcessor.__call__`] for details.
        decoder_input_ids (`np.ndarray` or `tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

@@ -299,17 +299,17 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
        Example:

        ```python
-       >>> from transformers import TFVisionEncoderDecoderModel, ViTFeatureExtractor, GPT2Tokenizer
+       >>> from transformers import TFVisionEncoderDecoderModel, ViTImageProcessor, GPT2Tokenizer
        >>> from PIL import Image
        >>> import requests

-       >>> feature_extractor = ViTFeatureExtractor.from_pretrained("ydshieh/vit-gpt2-coco-en")
+       >>> image_processor = ViTImageProcessor.from_pretrained("ydshieh/vit-gpt2-coco-en")
        >>> decoder_tokenizer = GPT2Tokenizer.from_pretrained("ydshieh/vit-gpt2-coco-en")
        >>> model = TFVisionEncoderDecoderModel.from_pretrained("ydshieh/vit-gpt2-coco-en")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> img = Image.open(requests.get(url, stream=True).raw)

-       >>> pixel_values = feature_extractor(images=img, return_tensors="tf").pixel_values  # Batch size 1
+       >>> pixel_values = image_processor(images=img, return_tensors="tf").pixel_values  # Batch size 1
        >>> output_ids = model.generate(
        ...     pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True

@@ -555,11 +555,11 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
        Examples:

        ```python
-       >>> from transformers import AutoFeatureExtractor, AutoTokenizer, TFVisionEncoderDecoderModel
+       >>> from transformers import AutoImageProcessor, AutoTokenizer, TFVisionEncoderDecoderModel
        >>> from PIL import Image
        >>> import requests

-       >>> feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
+       >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
        >>> decoder_tokenizer = AutoTokenizer.from_pretrained("gpt2")

        >>> # initialize a bert2gpt2 from a pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized

@@ -571,7 +571,7 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
        >>> img = Image.open(requests.get(url, stream=True).raw)

        >>> # forward
-       >>> pixel_values = feature_extractor(images=img, return_tensors="tf").pixel_values  # Batch size 1
+       >>> pixel_values = image_processor(images=img, return_tensors="tf").pixel_values  # Batch size 1
        >>> decoder_input_ids = decoder_tokenizer("Linda Davis", return_tensors="tf").input_ids  # Batch size 1

        >>> outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids)
...
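
For completeness (continuing the updated TF docstring example above, not part of the diff), the generated ids decode back to a caption with the decoder tokenizer:

```python
import requests
from PIL import Image
from transformers import GPT2Tokenizer, TFVisionEncoderDecoderModel, ViTImageProcessor

image_processor = ViTImageProcessor.from_pretrained("ydshieh/vit-gpt2-coco-en")
decoder_tokenizer = GPT2Tokenizer.from_pretrained("ydshieh/vit-gpt2-coco-en")
model = TFVisionEncoderDecoderModel.from_pretrained("ydshieh/vit-gpt2-coco-en")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
img = Image.open(requests.get(url, stream=True).raw)

pixel_values = image_processor(images=img, return_tensors="tf").pixel_values  # Batch size 1
output = model.generate(pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True)

# Decode the generated token ids into the caption string.
caption = decoder_tokenizer.batch_decode(output.sequences, skip_special_tokens=True)[0].strip()
print(caption)
```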
@@ -92,8 +92,8 @@ VISION_ENCODER_DECODER_START_DOCSTRING = r"""
VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
-           Pixel values. Pixel values can be obtained using a feature extractor (e.g. if you use ViT as the encoder,
-           you should use [`ViTFeatureExtractor`]). See [`ViTFeatureExtractor.__call__`] for details.
+           Pixel values. Pixel values can be obtained using an image processor (e.g. if you use ViT as the encoder,
+           you should use [`ViTImageProcessor`]). See [`ViTImageProcessor.__call__`] for details.
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

@@ -248,17 +248,17 @@ class VisionEncoderDecoderModel(PreTrainedModel):
        Example:

        ```python
-       >>> from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, GPT2Tokenizer
+       >>> from transformers import VisionEncoderDecoderModel, ViTImageProcessor, GPT2Tokenizer
        >>> from PIL import Image
        >>> import requests

-       >>> feature_extractor = ViTFeatureExtractor.from_pretrained("ydshieh/vit-gpt2-coco-en")
+       >>> image_processor = ViTImageProcessor.from_pretrained("ydshieh/vit-gpt2-coco-en")
        >>> decoder_tokenizer = GPT2Tokenizer.from_pretrained("ydshieh/vit-gpt2-coco-en")
        >>> model = VisionEncoderDecoderModel.from_pretrained("ydshieh/vit-gpt2-coco-en")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> img = Image.open(requests.get(url, stream=True).raw)

-       >>> pixel_values = feature_extractor(images=img, return_tensors="pt").pixel_values  # Batch size 1
+       >>> pixel_values = image_processor(images=img, return_tensors="pt").pixel_values  # Batch size 1
        >>> output_ids = model.generate(
        ...     pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True
...
@@ -96,7 +96,7 @@ VISION_TEXT_DUAL_ENCODER_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
-           [`CLIPFeatureExtractor`]. See [`CLIPFeatureExtractor.__call__`] for details.
+           [`CLIPImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.

@@ -131,8 +131,8 @@ VISION_TEXT_DUAL_ENCODER_INPUTS_DOCSTRING = r"""
            [What are position IDs?](../glossary#position-ids)
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
-           a feature extractor (e.g. if you use ViT as the encoder, you should use [`ViTFeatureExtractor`]). See
-           [`ViTFeatureExtractor.__call__`] for details.
+           an image processor (e.g. if you use ViT as the encoder, you should use [`ViTImageProcessor`]). See
+           [`ViTImageProcessor.__call__`] for details.
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
        output_attentions (`bool`, *optional*):

@@ -267,15 +267,15 @@ class VisionTextDualEncoderModel(PreTrainedModel):
        ```python
        >>> from PIL import Image
        >>> import requests
-       >>> from transformers import VisionTextDualEncoderModel, AutoFeatureExtractor
+       >>> from transformers import VisionTextDualEncoderModel, AutoImageProcessor

        >>> model = VisionTextDualEncoderModel.from_pretrained("clip-italian/clip-italian")
-       >>> feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
+       >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

-       >>> inputs = feature_extractor(images=image, return_tensors="pt")
+       >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> image_features = model.get_image_features(**inputs)
        ```"""

@@ -316,13 +316,13 @@ class VisionTextDualEncoderModel(PreTrainedModel):
        >>> from transformers import (
        ...     VisionTextDualEncoderModel,
        ...     VisionTextDualEncoderProcessor,
-       ...     ViTFeatureExtractor,
+       ...     ViTImageProcessor,
        ...     BertTokenizer,
        ... )

        >>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
-       >>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
-       >>> processor = VisionTextDualEncoderProcessor(feature_extractor, tokenizer)
+       >>> image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
+       >>> processor = VisionTextDualEncoderProcessor(image_processor, tokenizer)
        >>> model = VisionTextDualEncoderModel.from_vision_text_pretrained(
        ...     "google/vit-base-patch16-224", "bert-base-uncased"
        ... )
...
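
An end-to-end sketch of how the renamed pieces combine for the dual encoder (assumptions: the `clip-italian/clip-italian` checkpoint from the docstring above, that `AutoImageProcessor`/`AutoTokenizer` resolve for it, and a CLIP-style `logits_per_image` field in the output):

```python
import requests
import torch
from PIL import Image
from transformers import (
    AutoImageProcessor,
    AutoTokenizer,
    VisionTextDualEncoderModel,
    VisionTextDualEncoderProcessor,
)

model = VisionTextDualEncoderModel.from_pretrained("clip-italian/clip-italian")
tokenizer = AutoTokenizer.from_pretrained("clip-italian/clip-italian")
image_processor = AutoImageProcessor.from_pretrained("clip-italian/clip-italian")
processor = VisionTextDualEncoderProcessor(image_processor=image_processor, tokenizer=tokenizer)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(
    text=["una foto di un gatto", "una foto di un cane"],  # "a photo of a cat/dog"
    images=image,
    return_tensors="pt",
    padding=True,
)
with torch.no_grad():
    outputs = model(**inputs)

# Image-text similarity scores, normalized over the candidate captions.
print(outputs.logits_per_image.softmax(dim=-1))
```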
@@ -16,39 +16,56 @@
Processor class for VisionTextDualEncoder
"""

+ import warnings
+
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding


class VisionTextDualEncoderProcessor(ProcessorMixin):
    r"""
-   Constructs a VisionTextDualEncoder processor which wraps a vision feature extractor and a tokenizer into a single
+   Constructs a VisionTextDualEncoder processor which wraps an image processor and a tokenizer into a single
    processor.

-   [`VisionTextDualEncoderProcessor`] offers all the functionalities of [`AutoFeatureExtractor`] and
-   [`AutoTokenizer`]. See the [`~VisionTextDualEncoderProcessor.__call__`] and
-   [`~VisionTextDualEncoderProcessor.decode`] for more information.
+   [`VisionTextDualEncoderProcessor`] offers all the functionalities of [`AutoImageProcessor`] and [`AutoTokenizer`].
+   See the [`~VisionTextDualEncoderProcessor.__call__`] and [`~VisionTextDualEncoderProcessor.decode`] for more
+   information.

    Args:
-       feature_extractor ([`AutoFeatureExtractor`]):
-           The feature extractor is a required input.
+       image_processor ([`AutoImageProcessor`]):
+           The image processor is a required input.
        tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer is a required input.
    """

-   feature_extractor_class = "AutoFeatureExtractor"
+   attributes = ["image_processor", "tokenizer"]
+   image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

-   def __init__(self, feature_extractor, tokenizer):
-       super().__init__(feature_extractor, tokenizer)
-       self.current_processor = self.feature_extractor
+   def __init__(self, image_processor=None, tokenizer=None, **kwargs):
+       if "feature_extractor" in kwargs:
+           warnings.warn(
+               "The `feature_extractor` argument is deprecated and will be removed in v4.27, use `image_processor`"
+               " instead.",
+               FutureWarning,
+           )
+           feature_extractor = kwargs.pop("feature_extractor")
+
+       image_processor = image_processor if image_processor is not None else feature_extractor
+       if image_processor is None:
+           raise ValueError("You have to specify an image_processor.")
+       if tokenizer is None:
+           raise ValueError("You have to specify a tokenizer.")
+
+       super().__init__(image_processor, tokenizer)
+       self.current_processor = self.image_processor

    def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
        """
        Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
        and `kwargs` arguments to VisionTextDualEncoderTokenizer's [`~PreTrainedTokenizer.__call__`] if `text` is not
        `None` to encode the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
-       AutoFeatureExtractor's [`~AutoFeatureExtractor.__call__`] if `images` is not `None`. Please refer to the
-       doctsring of the above two methods for more information.
+       AutoImageProcessor's [`~AutoImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
+       of the above two methods for more information.

        Args:
            text (`str`, `List[str]`, `List[List[str]]`):

@@ -85,7 +102,7 @@ class VisionTextDualEncoderProcessor(ProcessorMixin):
            encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)

        if images is not None:
-           image_features = self.feature_extractor(images, return_tensors=return_tensors, **kwargs)
+           image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs)

        if text is not None and images is not None:
            encoding["pixel_values"] = image_features.pixel_values

@@ -108,3 +125,12 @@ class VisionTextDualEncoderProcessor(ProcessorMixin):
        Please refer to the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)
+
+   @property
+   def feature_extractor_class(self):
+       warnings.warn(
+           "`feature_extractor_class` is deprecated and will be removed in v4.27. Use `image_processor_class`"
+           " instead.",
+           FutureWarning,
+       )
+       return self.image_processor_class
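
A hedged sketch of the processor's `__call__` after the rename (the component checkpoints are illustrative): text is routed to the tokenizer and images to the image processor, and the two outputs are merged into one batch.

```python
import numpy as np
from PIL import Image
from transformers import BertTokenizer, ViTImageProcessor, VisionTextDualEncoderProcessor

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
processor = VisionTextDualEncoderProcessor(image_processor=image_processor, tokenizer=tokenizer)

# A random RGB image stands in for real data.
image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))

batch = processor(text="a photo of a cat", images=image, return_tensors="pt", padding=True)

# `input_ids`/`attention_mask` come from the tokenizer, `pixel_values` from the image processor.
print(sorted(batch.keys()))
```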
@@ -37,6 +37,7 @@ transformers_module = spec.loader.load_module()
AUTO_TO_BASE_CLASS_MAPPING = {
    "AutoTokenizer": "PreTrainedTokenizerBase",
    "AutoFeatureExtractor": "FeatureExtractionMixin",
+   "AutoImageProcessor": "ImageProcessingMixin",
}
...
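
The new mapping entry lets the repository consistency check validate attributes declared as `"AutoImageProcessor"` against `ImageProcessingMixin`. A hedged sketch of the kind of check this enables; the `is_valid_attribute` helper below is hypothetical, not the repository's actual checker:

```python
import transformers
from transformers import ViTImageProcessor

AUTO_TO_BASE_CLASS_MAPPING = {
    "AutoTokenizer": "PreTrainedTokenizerBase",
    "AutoFeatureExtractor": "FeatureExtractionMixin",
    "AutoImageProcessor": "ImageProcessingMixin",
}


def is_valid_attribute(instance, declared_auto_class: str) -> bool:
    # Resolve the declared auto class (e.g. "AutoImageProcessor") to the base class
    # its instances must inherit from, then check the instance against it.
    base_class = getattr(transformers, AUTO_TO_BASE_CLASS_MAPPING[declared_auto_class])
    return isinstance(instance, base_class)


print(is_valid_attribute(ViTImageProcessor(), "AutoImageProcessor"))  # True
```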
@@ -23,13 +23,13 @@ import numpy as np
from transformers import BertTokenizerFast
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES, BertTokenizer
from transformers.testing_utils import require_tokenizers, require_vision
- from transformers.utils import FEATURE_EXTRACTOR_NAME, is_vision_available
+ from transformers.utils import IMAGE_PROCESSOR_NAME, is_vision_available


if is_vision_available():
    from PIL import Image

-   from transformers import VisionTextDualEncoderProcessor, ViTFeatureExtractor
+   from transformers import VisionTextDualEncoderProcessor, ViTImageProcessor


@require_tokenizers

@@ -45,22 +45,22 @@ class VisionTextDualEncoderProcessorTest(unittest.TestCase):
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

-       feature_extractor_map = {
+       image_processor_map = {
            "do_resize": True,
-           "size": 18,
+           "size": {"height": 18, "width": 18},
            "do_normalize": True,
            "image_mean": [0.5, 0.5, 0.5],
            "image_std": [0.5, 0.5, 0.5],
        }
-       self.feature_extractor_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME)
-       with open(self.feature_extractor_file, "w", encoding="utf-8") as fp:
-           json.dump(feature_extractor_map, fp)
+       self.image_processor_file = os.path.join(self.tmpdirname, IMAGE_PROCESSOR_NAME)
+       with open(self.image_processor_file, "w", encoding="utf-8") as fp:
+           json.dump(image_processor_map, fp)

    def get_tokenizer(self, **kwargs):
        return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)

-   def get_feature_extractor(self, **kwargs):
-       return ViTFeatureExtractor.from_pretrained(self.tmpdirname, **kwargs)
+   def get_image_processor(self, **kwargs):
+       return ViTImageProcessor.from_pretrained(self.tmpdirname, **kwargs)

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

@@ -76,13 +76,11 @@ class VisionTextDualEncoderProcessorTest(unittest.TestCase):
        return image_inputs

-   # TODO (Amy): fix me
-   @unittest.skip("An issue introduced in PR #19796 will be fixed by `AutoImageProcessor`")
    def test_save_load_pretrained_default(self):
        tokenizer = self.get_tokenizer()
-       feature_extractor = self.get_feature_extractor()
+       image_processor = self.get_image_processor()

-       processor = VisionTextDualEncoderProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+       processor = VisionTextDualEncoderProcessor(tokenizer=tokenizer, image_processor=image_processor)
        processor.save_pretrained(self.tmpdirname)
        processor = VisionTextDualEncoderProcessor.from_pretrained(self.tmpdirname)

@@ -90,19 +88,17 @@ class VisionTextDualEncoderProcessorTest(unittest.TestCase):
        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
        self.assertIsInstance(processor.tokenizer, (BertTokenizer, BertTokenizerFast))

-       self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
-       self.assertIsInstance(processor.feature_extractor, ViTFeatureExtractor)
+       self.assertEqual(processor.image_processor.to_json_string(), image_processor.to_json_string())
+       self.assertIsInstance(processor.image_processor, ViTImageProcessor)

-   # TODO (Amy): fix me
-   @unittest.skip("An issue introduced in PR #19796 will be fixed by `AutoImageProcessor`")
    def test_save_load_pretrained_additional_features(self):
        processor = VisionTextDualEncoderProcessor(
-           tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor()
+           tokenizer=self.get_tokenizer(), image_processor=self.get_image_processor()
        )
        processor.save_pretrained(self.tmpdirname)

        tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
-       feature_extractor_add_kwargs = self.get_feature_extractor(do_normalize=False, padding_value=1.0)
+       image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0)

        processor = VisionTextDualEncoderProcessor.from_pretrained(
            self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0

@@ -111,28 +107,28 @@ class VisionTextDualEncoderProcessorTest(unittest.TestCase):
        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
        self.assertIsInstance(processor.tokenizer, (BertTokenizer, BertTokenizerFast))

-       self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
-       self.assertIsInstance(processor.feature_extractor, ViTFeatureExtractor)
+       self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string())
+       self.assertIsInstance(processor.image_processor, ViTImageProcessor)

-   def test_feature_extractor(self):
-       feature_extractor = self.get_feature_extractor()
+   def test_image_processor(self):
+       image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

-       processor = VisionTextDualEncoderProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+       processor = VisionTextDualEncoderProcessor(tokenizer=tokenizer, image_processor=image_processor)

        image_input = self.prepare_image_inputs()

-       input_feat_extract = feature_extractor(image_input, return_tensors="np")
+       input_feat_extract = image_processor(image_input, return_tensors="np")
        input_processor = processor(images=image_input, return_tensors="np")

        for key in input_feat_extract.keys():
            self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)

    def test_tokenizer(self):
-       feature_extractor = self.get_feature_extractor()
+       image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

-       processor = VisionTextDualEncoderProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+       processor = VisionTextDualEncoderProcessor(tokenizer=tokenizer, image_processor=image_processor)

        input_str = "lower newer"

@@ -144,10 +140,10 @@ class VisionTextDualEncoderProcessorTest(unittest.TestCase):
            self.assertListEqual(encoded_tok[key], encoded_processor[key])

    def test_processor(self):
-       feature_extractor = self.get_feature_extractor()
+       image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

-       processor = VisionTextDualEncoderProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+       processor = VisionTextDualEncoderProcessor(tokenizer=tokenizer, image_processor=image_processor)

        input_str = "lower newer"
        image_input = self.prepare_image_inputs()

@@ -161,10 +157,10 @@ class VisionTextDualEncoderProcessorTest(unittest.TestCase):
            processor()

    def test_tokenizer_decode(self):
-       feature_extractor = self.get_feature_extractor()
+       image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

-       processor = VisionTextDualEncoderProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+       processor = VisionTextDualEncoderProcessor(tokenizer=tokenizer, image_processor=image_processor)

        predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
...
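
One API detail surfaced by the test changes (not part of the commit message): image processors express `size` as a dict of edge lengths rather than the bare int that feature extractors accepted. A minimal sketch of the equivalent configuration:

```python
from transformers import ViTImageProcessor

# Equivalent of the test fixture above: `size` is now a dict instead of a plain int.
image_processor = ViTImageProcessor(
    do_resize=True,
    size={"height": 18, "width": 18},
    do_normalize=True,
    image_mean=[0.5, 0.5, 0.5],
    image_std=[0.5, 0.5, 0.5],
)
print(image_processor.size)  # {'height': 18, 'width': 18}
```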