Unverified Commit 67a35114 authored by amyeroberts, committed by GitHub

Update PT to TF CLI for audio models (#19465)

* Update PT to TF CLI model inputs

* Get padding strategy if specified

* Make False comparison explicit
parent 8d68878c
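
The last two bullets describe how the command now resolves the `padding` argument for audio feature extractors. The snippet below is a minimal standalone sketch of that resolution logic, not the CLI code itself; `resolve_padding_strategy` and the two dummy call signatures are hypothetical names used only for illustration.

import inspect


def resolve_padding_strategy(feature_extractor):
    # Use the extractor's own default for `padding` (e.g. "max_length") when it
    # declares one; otherwise pad to the longest sample in the batch.
    params = inspect.signature(feature_extractor).parameters
    if "padding" in params:
        default_strategy = params["padding"].default
        # Explicit comparison: False and None mean "no usable default here",
        # while any other value (such as the string "max_length") is kept.
        if default_strategy is not False and default_strategy is not None:
            return default_strategy
    return True


def whisper_like(raw_speech, padding="max_length"):  # hypothetical signature
    ...


def wav2vec2_like(raw_speech, padding=False):  # hypothetical signature
    ...


print(resolve_padding_strategy(whisper_like))   # max_length
print(resolve_padding_strategy(wav2vec2_like))  # True

The explicit `is not False and is not None` checks are what the third bullet refers to: only those two values trigger the fallback to padding to the longest sample, and any other declared default is forwarded unchanged.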
@@ -197,6 +197,20 @@ class PTtoTFCommand(BaseTransformersCLICommand):
             raw_samples = [x["array"] for x in speech_samples]
             return raw_samples

+        model_config_class = type(pt_model.config)
+        if model_config_class in PROCESSOR_MAPPING:
+            processor = AutoProcessor.from_pretrained(self._local_dir)
+            if model_config_class in TOKENIZER_MAPPING and processor.tokenizer.pad_token is None:
+                processor.tokenizer.pad_token = processor.tokenizer.eos_token
+        elif model_config_class in FEATURE_EXTRACTOR_MAPPING:
+            processor = AutoFeatureExtractor.from_pretrained(self._local_dir)
+        elif model_config_class in TOKENIZER_MAPPING:
+            processor = AutoTokenizer.from_pretrained(self._local_dir)
+            if processor.pad_token is None:
+                processor.pad_token = processor.eos_token
+        else:
+            raise ValueError(f"Unknown data processing type (model config type: {model_config_class})")
+
         model_forward_signature = set(inspect.signature(pt_model.forward).parameters.keys())
         processor_inputs = {}
         if "input_ids" in model_forward_signature:
@@ -211,24 +225,20 @@ class PTtoTFCommand(BaseTransformersCLICommand):
             sample_images = load_dataset("cifar10", "plain_text", split="test")[:2]["img"]
             processor_inputs.update({"images": sample_images})
         if "input_features" in model_forward_signature:
-            processor_inputs.update({"raw_speech": _get_audio_input(), "padding": True})
-        if "input_values" in model_forward_signature:  # Wav2Vec2 audio input
-            processor_inputs.update({"raw_speech": _get_audio_input(), "padding": True})
-
-        model_config_class = type(pt_model.config)
-        if model_config_class in PROCESSOR_MAPPING:
-            processor = AutoProcessor.from_pretrained(self._local_dir)
-            if model_config_class in TOKENIZER_MAPPING and processor.tokenizer.pad_token is None:
-                processor.tokenizer.pad_token = processor.tokenizer.eos_token
-        elif model_config_class in FEATURE_EXTRACTOR_MAPPING:
-            processor = AutoFeatureExtractor.from_pretrained(self._local_dir)
-        elif model_config_class in TOKENIZER_MAPPING:
-            processor = AutoTokenizer.from_pretrained(self._local_dir)
-            if processor.pad_token is None:
-                processor.pad_token = processor.eos_token
-        else:
-            raise ValueError(f"Unknown data processing type (model config type: {model_config_class})")
+            feature_extractor_signature = inspect.signature(processor.feature_extractor).parameters
+            # Pad to the largest input length by default but take feature extractor default
+            # padding value if it exists e.g. "max_length" and is not False or None
+            if "padding" in feature_extractor_signature:
+                default_strategy = feature_extractor_signature["padding"].default
+                if default_strategy is not False and default_strategy is not None:
+                    padding_strategy = default_strategy
+                else:
+                    padding_strategy = True
+            else:
+                padding_strategy = True
+            processor_inputs.update({"audio": _get_audio_input(), "padding": padding_strategy})
+        if "input_values" in model_forward_signature:  # Wav2Vec2 audio input
+            processor_inputs.update({"audio": _get_audio_input(), "padding": True})

         pt_input = processor(**processor_inputs, return_tensors="pt")
         tf_input = processor(**processor_inputs, return_tensors="tf")
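
The first hunk moves the existing processor-selection block ahead of the input preparation, so that `processor.feature_extractor` already exists when its `padding` default is inspected in the second hunk. The snippet below is a rough standalone sketch of that selection order, not the command itself; it assumes the auto mappings can be imported from `transformers.models.auto` and uses `facebook/wav2vec2-base-960h` purely as an example checkpoint (any model id or local directory would do).

from transformers import AutoConfig, AutoFeatureExtractor, AutoProcessor, AutoTokenizer
from transformers.models.auto import FEATURE_EXTRACTOR_MAPPING, PROCESSOR_MAPPING, TOKENIZER_MAPPING

model_id = "facebook/wav2vec2-base-960h"  # example checkpoint id
config_class = type(AutoConfig.from_pretrained(model_id))

# Same precedence as the CLI: prefer a full processor, then a bare feature
# extractor, then a bare tokenizer, giving tokenizers a pad token if missing.
if config_class in PROCESSOR_MAPPING:
    processor = AutoProcessor.from_pretrained(model_id)
    if config_class in TOKENIZER_MAPPING and processor.tokenizer.pad_token is None:
        processor.tokenizer.pad_token = processor.tokenizer.eos_token
elif config_class in FEATURE_EXTRACTOR_MAPPING:
    processor = AutoFeatureExtractor.from_pretrained(model_id)
elif config_class in TOKENIZER_MAPPING:
    processor = AutoTokenizer.from_pretrained(model_id)
    if processor.pad_token is None:
        processor.pad_token = processor.eos_token
else:
    raise ValueError(f"Unknown data processing type (model config type: {config_class})")

print(type(processor).__name__)  # e.g. Wav2Vec2Processor for this checkpoint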