"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "003c4771290b00e6d14b871210c3a369edccaeed"
Unverified commit 905e5773, authored by Sanchit Gandhi, committed by GitHub

[processor] Add 'model input names' property (#20117)

* [processor] Add 'model input names' property

* add test

* no f string

* add generic property method to mixin

* copy to multimodal

* copy to vision

* tests for all audio

* remove ad-hoc tests

* style

* fix flava test

* fix test

* fix processor code
parent 68187c46
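
For context, a minimal sketch of how the new `model_input_names` property is meant to be used from user code. The checkpoint name is an assumption chosen for illustration and the exact keys depend on that checkpoint's tokenizer and feature extractor; neither is part of this diff.

```python
from transformers import CLIPProcessor

# Hypothetical usage of the property added in this PR; the checkpoint is an
# assumption and the printed keys depend on the underlying components.
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Reports the input names the downstream model expects, merged from the
# tokenizer and the feature extractor with duplicates removed,
# e.g. ['input_ids', 'attention_mask', 'pixel_values'] for CLIP.
print(processor.model_input_names)
```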
@@ -105,3 +105,9 @@ class CLIPProcessor(ProcessorMixin):
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        feature_extractor_input_names = self.feature_extractor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names))
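The `list(dict.fromkeys(...))` idiom used here, and repeated in the other multimodal processors below, is an order-preserving de-duplication. A standalone illustration with assumed input-name lists (the values are illustrative, not taken from any real processor):

```python
# dict keys keep insertion order (guaranteed since Python 3.7), so duplicates
# such as "attention_mask" appearing in both lists are dropped while the first
# occurrence of each name keeps its position.
tokenizer_input_names = ["input_ids", "attention_mask"]
feature_extractor_input_names = ["pixel_values", "attention_mask"]

merged = list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names))
print(merged)  # ['input_ids', 'attention_mask', 'pixel_values']
```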
@@ -122,3 +122,9 @@ class FlavaProcessor(ProcessorMixin):
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        feature_extractor_input_names = self.feature_extractor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names))
@@ -158,3 +158,7 @@ class LayoutLMv2Processor(ProcessorMixin):
        to the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        return ["input_ids", "bbox", "token_type_ids", "attention_mask", "image"]
@@ -156,3 +156,7 @@ class LayoutLMv3Processor(ProcessorMixin):
        to the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        return ["input_ids", "bbox", "attention_mask", "pixel_values"]
@@ -158,3 +158,7 @@ class LayoutXLMProcessor(ProcessorMixin):
        to the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        return ["input_ids", "bbox", "attention_mask", "image"]
@@ -138,3 +138,8 @@ class MarkupLMProcessor(ProcessorMixin):
        docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        return tokenizer_input_names
@@ -159,3 +159,9 @@ class OwlViTProcessor(ProcessorMixin):
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        feature_extractor_input_names = self.feature_extractor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names))
@@ -106,3 +106,9 @@ class ViltProcessor(ProcessorMixin):
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        feature_extractor_input_names = self.feature_extractor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names))
@@ -127,6 +127,11 @@ class VisionTextDualEncoderProcessor(ProcessorMixin):
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

    def feature_extractor_class(self):
        warnings.warn(
            "`feature_extractor_class` is deprecated and will be removed in v4.27. Use `image_processor_class`"
@@ -107,3 +107,7 @@ class XCLIPProcessor(ProcessorMixin):
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        return ["input_ids", "attention_mask", "position_ids", "pixel_values"]
@@ -227,6 +227,11 @@ class ProcessorMixin(PushToHubMixin):
            args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
        return args

    @property
    def model_input_names(self):
        first_attribute = getattr(self, self.attributes[0])
        return getattr(first_attribute, "model_input_names", None)


ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub)
ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format(
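
The mixin default above simply defers to the processor's first attribute (typically the feature extractor or tokenizer), so subclasses only need to override it when the combined input names differ. A hedged, self-contained sketch of that fallback behaviour using stand-in classes (none of these names are real transformers objects):

```python
# Minimal sketch of the ProcessorMixin fallback; FakeFeatureExtractor,
# FakeTokenizer and FakeProcessor are illustrative stand-ins.
class FakeFeatureExtractor:
    model_input_names = ["input_features", "attention_mask"]


class FakeTokenizer:
    model_input_names = ["input_ids", "attention_mask"]


class FakeProcessor:
    # Mirrors the mixin: `attributes[0]` names the component whose
    # input names are reported by default.
    attributes = ["feature_extractor", "tokenizer"]

    def __init__(self):
        self.feature_extractor = FakeFeatureExtractor()
        self.tokenizer = FakeTokenizer()

    @property
    def model_input_names(self):
        first_attribute = getattr(self, self.attributes[0])
        return getattr(first_attribute, "model_input_names", None)


print(FakeProcessor().model_input_names)  # ['input_features', 'attention_mask']
```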
@@ -187,3 +187,16 @@ class CLIPProcessorTest(unittest.TestCase):
        decoded_tok = tokenizer.batch_decode(predicted_ids)

        self.assertListEqual(decoded_tok, decoded_processor)

    def test_model_input_names(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = CLIPProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        input_str = "lower newer"
        image_input = self.prepare_image_inputs()

        inputs = processor(text=input_str, images=image_input)

        self.assertListEqual(list(inputs.keys()), processor.model_input_names)
@@ -231,3 +231,16 @@ class FlavaProcessorTest(unittest.TestCase):
        decoded_tok = tokenizer.batch_decode(predicted_ids)

        self.assertListEqual(decoded_tok, decoded_processor)

    def test_model_input_names(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = FlavaProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        input_str = "lower newer"
        image_input = self.prepare_image_inputs()

        inputs = processor(text=input_str, images=image_input)

        self.assertListEqual(list(inputs.keys()), processor.model_input_names)
@@ -19,6 +19,8 @@ import tempfile
import unittest
from typing import List

import numpy as np

from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
from transformers.models.layoutlmv2 import LayoutLMv2Tokenizer, LayoutLMv2TokenizerFast
from transformers.models.layoutlmv2.tokenization_layoutlmv2 import VOCAB_FILES_NAMES
@@ -86,6 +88,17 @@ class LayoutLMv2ProcessorTest(unittest.TestCase):
    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

    def prepare_image_inputs(self):
        """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
        or a list of PyTorch tensors if one specifies torchify=True.
        """

        image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]

        image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]

        return image_inputs

    def test_save_load_pretrained_default(self):
        feature_extractor = self.get_feature_extractor()
        tokenizers = self.get_tokenizers()
@@ -133,6 +146,20 @@ class LayoutLMv2ProcessorTest(unittest.TestCase):
        self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
        self.assertIsInstance(processor.feature_extractor, LayoutLMv2FeatureExtractor)

    def test_model_input_names(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = LayoutLMv2Processor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        input_str = "lower newer"
        image_input = self.prepare_image_inputs()

        # add extra args
        inputs = processor(text=input_str, images=image_input, return_codebook_pixels=False, return_image_mask=False)

        self.assertListEqual(list(inputs.keys()), processor.model_input_names)

    @slow
    def test_overflowing_tokens(self):
        # In the case of overflowing tokens, test that we still have 1-to-1 mapping between the images and input_ids (sequences that are too long are broken down into multiple sequences).
@@ -19,6 +19,8 @@ import tempfile
import unittest
from typing import List

import numpy as np

from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
from transformers.models.layoutlmv3 import LayoutLMv3Tokenizer, LayoutLMv3TokenizerFast
from transformers.models.layoutlmv3.tokenization_layoutlmv3 import VOCAB_FILES_NAMES
@@ -99,6 +101,17 @@ class LayoutLMv3ProcessorTest(unittest.TestCase):
    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

    def prepare_image_inputs(self):
        """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
        or a list of PyTorch tensors if one specifies torchify=True.
        """

        image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]

        image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]

        return image_inputs

    def test_save_load_pretrained_default(self):
        feature_extractor = self.get_feature_extractor()
        tokenizers = self.get_tokenizers()
@@ -146,6 +159,20 @@ class LayoutLMv3ProcessorTest(unittest.TestCase):
        self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
        self.assertIsInstance(processor.feature_extractor, LayoutLMv3FeatureExtractor)

    def test_model_input_names(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = LayoutLMv3Processor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        input_str = "lower newer"
        image_input = self.prepare_image_inputs()

        # add extra args
        inputs = processor(text=input_str, images=image_input, return_codebook_pixels=False, return_image_mask=False)

        self.assertListEqual(list(inputs.keys()), processor.model_input_names)

    # different use cases tests
    @require_torch
@@ -19,6 +19,8 @@ import tempfile
import unittest
from typing import List

import numpy as np

from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
from transformers.models.layoutxlm import LayoutXLMTokenizer, LayoutXLMTokenizerFast
from transformers.testing_utils import (
@@ -74,6 +76,17 @@ class LayoutXLMProcessorTest(unittest.TestCase):
    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

    def prepare_image_inputs(self):
        """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
        or a list of PyTorch tensors if one specifies torchify=True.
        """

        image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]

        image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]

        return image_inputs

    def test_save_load_pretrained_default(self):
        feature_extractor = self.get_feature_extractor()
        tokenizers = self.get_tokenizers()
@@ -126,6 +139,20 @@ class LayoutXLMProcessorTest(unittest.TestCase):
        self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
        self.assertIsInstance(processor.feature_extractor, LayoutLMv2FeatureExtractor)

    def test_model_input_names(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = LayoutXLMProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        input_str = "lower newer"
        image_input = self.prepare_image_inputs()

        # add extra args
        inputs = processor(text=input_str, images=image_input, return_codebook_pixels=False, return_image_mask=False)

        self.assertListEqual(list(inputs.keys()), processor.model_input_names)

    @slow
    def test_overflowing_tokens(self):
        # In the case of overflowing tokens, test that we still have 1-to-1 mapping between the images and input_ids (sequences that are too long are broken down into multiple sequences).
@@ -133,6 +133,18 @@ class MarkupLMProcessorTest(unittest.TestCase):
        self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
        self.assertIsInstance(processor.feature_extractor, MarkupLMFeatureExtractor)

    def test_model_input_names(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = MarkupLMProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        self.assertListEqual(
            processor.model_input_names,
            tokenizer.model_input_names,
            msg="`processor` and `tokenizer` model input names do not match",
        )

    # different use cases tests
    @require_bs4
@@ -144,3 +144,15 @@ class MCTCTProcessorTest(unittest.TestCase):
        decoded_tok = tokenizer.batch_decode(predicted_ids)

        self.assertListEqual(decoded_tok, decoded_processor)

    def test_model_input_names(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = MCTCTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        self.assertListEqual(
            processor.model_input_names,
            feature_extractor.model_input_names,
            msg="`processor` and `feature_extractor` model input names do not match",
        )
@@ -239,3 +239,16 @@ class OwlViTProcessorTest(unittest.TestCase):
        decoded_tok = tokenizer.batch_decode(predicted_ids)

        self.assertListEqual(decoded_tok, decoded_processor)

    def test_model_input_names(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = OwlViTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        input_str = "lower newer"
        image_input = self.prepare_image_inputs()

        inputs = processor(text=input_str, images=image_input)

        self.assertListEqual(list(inputs.keys()), processor.model_input_names)
@@ -144,3 +144,15 @@ class Speech2TextProcessorTest(unittest.TestCase):
        decoded_tok = tokenizer.batch_decode(predicted_ids)

        self.assertListEqual(decoded_tok, decoded_processor)

    def test_model_input_names(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = Speech2TextProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        self.assertListEqual(
            processor.model_input_names,
            feature_extractor.model_input_names,
            msg="`processor` and `feature_extractor` model input names do not match",
        )