Unverified Commit cc5c061e authored by Joao Gante, committed by GitHub

CLI: handle multimodal inputs (#17839)

parent e8eb699e
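This commit replaces the modality-specific input helpers (`get_text_inputs`, `get_image_inputs`, `get_audio_inputs`) of the `pt-to-tf` CLI command (invoked, roughly, as `transformers-cli pt-to-tf --model-name <checkpoint>`) with a single `get_inputs` method that inspects the PyTorch model's forward signature and builds one matching processor call, so multimodal models can also be validated. As a standalone illustration of the signature-inspection idea (a minimal sketch; `DummyVisionModel` and `detect_modalities` are hypothetical names for this illustration only, not part of the diff):

# Minimal sketch of the signature-based dispatch used by get_inputs() below.
import inspect

import torch


class DummyVisionModel(torch.nn.Module):
    # A toy model whose forward pass only accepts image-like inputs.
    def forward(self, pixel_values=None, attention_mask=None):
        return pixel_values


def detect_modalities(model):
    # The argument names accepted by the forward pass tell us which
    # modalities the model consumes.
    forward_params = set(inspect.signature(model.forward).parameters.keys())
    modalities = []
    if "input_ids" in forward_params:
        modalities.append("text")
    if "pixel_values" in forward_params:
        modalities.append("image")
    if "input_features" in forward_params or "input_values" in forward_params:
        modalities.append("audio")
    return modalities


print(detect_modalities(DummyVisionModel()))  # prints: ['image']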
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import os
 from argparse import ArgumentParser, Namespace
 from importlib import import_module
@@ -22,7 +23,17 @@
 from packaging import version
 
 import huggingface_hub
 
-from .. import AutoConfig, AutoFeatureExtractor, AutoTokenizer, is_tf_available, is_torch_available
+from .. import (
+    FEATURE_EXTRACTOR_MAPPING,
+    PROCESSOR_MAPPING,
+    TOKENIZER_MAPPING,
+    AutoConfig,
+    AutoFeatureExtractor,
+    AutoProcessor,
+    AutoTokenizer,
+    is_tf_available,
+    is_torch_available,
+)
 from ..utils import logging
 from . import BaseTransformersCLICommand
@@ -161,31 +172,58 @@ class PTtoTFCommand(BaseTransformersCLICommand):
         self._push = push
         self._extra_commit_description = extra_commit_description
 
-    def get_text_inputs(self):
-        tokenizer = AutoTokenizer.from_pretrained(self._local_dir)
-        sample_text = ["Hi there!", "I am a batch with more than one row and different input lengths."]
-        if tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.eos_token
-        pt_input = tokenizer(sample_text, return_tensors="pt", padding=True, truncation=True)
-        tf_input = tokenizer(sample_text, return_tensors="tf", padding=True, truncation=True)
-        return pt_input, tf_input
-
-    def get_audio_inputs(self):
-        processor = AutoFeatureExtractor.from_pretrained(self._local_dir)
-        num_samples = 2
-        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
-        speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
-        raw_samples = [x["array"] for x in speech_samples]
-        pt_input = processor(raw_samples, return_tensors="pt", padding=True)
-        tf_input = processor(raw_samples, return_tensors="tf", padding=True)
-        return pt_input, tf_input
-
-    def get_image_inputs(self):
-        feature_extractor = AutoFeatureExtractor.from_pretrained(self._local_dir)
-        num_samples = 2
-        ds = load_dataset("cifar10", "plain_text", split="test")[:num_samples]["img"]
-        pt_input = feature_extractor(images=ds, return_tensors="pt")
-        tf_input = feature_extractor(images=ds, return_tensors="tf")
-        return pt_input, tf_input
+    def get_inputs(self, pt_model, config):
+        """
+        Returns the right inputs for the model, based on its signature.
+        """
+
+        def _get_audio_input():
+            ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+            speech_samples = ds.sort("id").select(range(2))[:2]["audio"]
+            raw_samples = [x["array"] for x in speech_samples]
+            return raw_samples
+
+        model_forward_signature = set(inspect.signature(pt_model.forward).parameters.keys())
+        processor_inputs = {}
+        if "input_ids" in model_forward_signature:
+            processor_inputs.update(
+                {
+                    "text": ["Hi there!", "I am a batch with more than one row and different input lengths."],
+                    "padding": True,
+                    "truncation": True,
+                }
+            )
+        if "pixel_values" in model_forward_signature:
+            sample_images = load_dataset("cifar10", "plain_text", split="test")[:2]["img"]
+            processor_inputs.update({"images": sample_images})
+        if "input_features" in model_forward_signature:
+            processor_inputs.update({"raw_speech": _get_audio_input(), "padding": True})
+        if "input_values" in model_forward_signature:  # Wav2Vec2 audio input
+            processor_inputs.update({"raw_speech": _get_audio_input(), "padding": True})
+
+        model_config_class = type(pt_model.config)
+        if model_config_class in PROCESSOR_MAPPING:
+            processor = AutoProcessor.from_pretrained(self._local_dir)
+            if model_config_class in TOKENIZER_MAPPING and processor.tokenizer.pad_token is None:
+                processor.tokenizer.pad_token = processor.tokenizer.eos_token
+        elif model_config_class in FEATURE_EXTRACTOR_MAPPING:
+            processor = AutoFeatureExtractor.from_pretrained(self._local_dir)
+        elif model_config_class in TOKENIZER_MAPPING:
+            processor = AutoTokenizer.from_pretrained(self._local_dir)
+            if processor.pad_token is None:
+                processor.pad_token = processor.eos_token
+        else:
+            raise ValueError(f"Unknown data processing type (model config type: {model_config_class})")
+
+        pt_input = processor(**processor_inputs, return_tensors="pt")
+        tf_input = processor(**processor_inputs, return_tensors="tf")
+
+        # Extra input requirements, in addition to the input modality
+        if config.is_encoder_decoder or (hasattr(pt_model, "encoder") and hasattr(pt_model, "decoder")):
+            decoder_input_ids = np.asarray([[1], [1]], dtype=int) * (pt_model.config.decoder_start_token_id or 0)
+            pt_input.update({"decoder_input_ids": torch.tensor(decoder_input_ids)})
+            tf_input.update({"decoder_input_ids": tf.convert_to_tensor(decoder_input_ids)})
+
+        return pt_input, tf_input
 
     def run(self):
@@ -218,24 +256,10 @@ class PTtoTFCommand(BaseTransformersCLICommand):
         except AttributeError:
             raise AttributeError(f"The TensorFlow equivalent of {architectures[0]} doesn't exist in transformers.")
 
-        # Load models and acquire a basic input for its modality.
+        # Load models and acquire a basic input compatible with the model.
         pt_model = pt_class.from_pretrained(self._local_dir)
-        main_input_name = pt_model.main_input_name
-        if main_input_name == "input_ids":
-            pt_input, tf_input = self.get_text_inputs()
-        elif main_input_name == "pixel_values":
-            pt_input, tf_input = self.get_image_inputs()
-        elif main_input_name == "input_features":
-            pt_input, tf_input = self.get_audio_inputs()
-        else:
-            raise ValueError(f"Can't detect the model modality (`main_input_name` = {main_input_name})")
         tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)
+        pt_input, tf_input = self.get_inputs(pt_model, config)
 
-        # Extra input requirements, in addition to the input modality
-        if config.is_encoder_decoder or (hasattr(pt_model, "encoder") and hasattr(pt_model, "decoder")):
-            decoder_input_ids = np.asarray([[1], [1]], dtype=int) * pt_model.config.decoder_start_token_id
-            pt_input.update({"decoder_input_ids": torch.tensor(decoder_input_ids)})
-            tf_input.update({"decoder_input_ids": tf.convert_to_tensor(decoder_input_ids)})
-
         # Confirms that cross loading PT weights into TF worked.
         crossload_differences = self.find_pt_tf_differences(pt_model, pt_input, tf_from_pt_model, tf_input)
...
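For encoder-decoder architectures, `get_inputs` additionally seeds `decoder_input_ids` for both frameworks as a batch of two single start tokens, matching the two-row inputs it builds. A self-contained sketch of that seeding (the start token id here is hypothetical; the real command reads `pt_model.config.decoder_start_token_id`, falling back to 0 when it is unset):

# Sketch of the decoder_input_ids seeding, assuming a start token id of 2.
import numpy as np
import tensorflow as tf
import torch

decoder_start_token_id = 2  # hypothetical; normally read from the model config

# One start token per row of the batch-of-two inputs.
decoder_input_ids = np.asarray([[1], [1]], dtype=int) * decoder_start_token_id
pt_decoder_input_ids = torch.tensor(decoder_input_ids)
tf_decoder_input_ids = tf.convert_to_tensor(decoder_input_ids)
print(pt_decoder_input_ids.shape, tf_decoder_input_ids.shape)  # (2, 1) each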