Unverified Commit 905e5773 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[processor] Add 'model input names' property (#20117)

* [processor] Add 'model input names' property

* add test

* no f string

* add generic property method to mixin

* copy to multimodal

* copy to vision

* tests for all audio

* remove ad-hoc tests

* style

* fix flava test

* fix test

* fix processor code
parent 68187c46
......@@ -168,3 +168,16 @@ class VisionTextDualEncoderProcessorTest(unittest.TestCase):
decoded_tok = tokenizer.batch_decode(predicted_ids)
self.assertListEqual(decoded_tok, decoded_processor)
def test_model_input_names(self):
feature_extractor = self.get_image_processor()
tokenizer = self.get_tokenizer()
processor = VisionTextDualEncoderProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input)
self.assertListEqual(list(inputs.keys()), processor.model_input_names)
......@@ -137,3 +137,15 @@ class Wav2Vec2ProcessorTest(unittest.TestCase):
decoded_tok = tokenizer.batch_decode(predicted_ids)
self.assertListEqual(decoded_tok, decoded_processor)
def test_model_input_names(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = Wav2Vec2Processor(tokenizer=tokenizer, feature_extractor=feature_extractor)
self.assertListEqual(
processor.model_input_names,
feature_extractor.model_input_names,
msg="`processor` and `feature_extractor` model input names do not match",
)
......@@ -367,6 +367,19 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
self.assertListEqual(decoded_wav2vec2.text, decoded_auto.text)
def test_model_input_names(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
decoder = self.get_decoder()
processor = Wav2Vec2ProcessorWithLM(tokenizer=tokenizer, feature_extractor=feature_extractor, decoder=decoder)
self.assertListEqual(
processor.model_input_names,
feature_extractor.model_input_names,
msg="`processor` and `feature_extractor` model input names do not match",
)
@staticmethod
def get_from_offsets(offsets, key):
retrieved_list = [d[key] for d in offsets]
......
......@@ -116,3 +116,15 @@ class WhisperProcessorTest(unittest.TestCase):
decoded_tok = tokenizer.batch_decode(predicted_ids)
self.assertListEqual(decoded_tok, decoded_processor)
def test_model_input_names(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = WhisperProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
self.assertListEqual(
processor.model_input_names,
feature_extractor.model_input_names,
msg="`processor` and `feature_extractor` model input names do not match",
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment