Unverified Commit 6b093283 authored by lewtun, committed by GitHub

Fix duplicate arguments passed to dummy inputs in ONNX export (#16045)

* Fix duplicate arguments passed to dummy inputs in ONNX export

* Fix M2M100 ONNX config

* Ensure we check PreTrainedModel only if torch is available

* Remove TensorFlow tests for models without PyTorch parity
parent ba21001f
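For context, a minimal, hypothetical sketch of the failure mode the first bullet describes (simplified names, not the library code): once the deprecation shim rebinds `preprocessor = tokenizer`, also forwarding `tokenizer=tokenizer` to `generate_dummy_inputs` hands the config the same object twice, which is why the duplicate keyword is dropped at every call site in the diff below.
# Hypothetical, simplified stand-in for OnnxConfig.generate_dummy_inputs
def generate_dummy_inputs(preprocessor=None, tokenizer=None, framework=None):
    if preprocessor is not None and tokenizer is not None:
        raise ValueError("You cannot provide both a tokenizer and a preprocessor to export the model.")
    source = preprocessor if preprocessor is not None else tokenizer
    return {"source": source, "framework": framework}

tokenizer = object()      # stand-in for a PreTrainedTokenizer
preprocessor = tokenizer  # what the deprecation shim does when only `tokenizer` is passed

# Before the fix both arguments were forwarded, tripping the guard above;
# after the fix only the (possibly rebound) preprocessor is forwarded.
model_inputs = generate_dummy_inputs(preprocessor, framework="pt")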
@@ -198,13 +198,13 @@ class M2M100OnnxConfig(OnnxSeq2SeqConfigWithPast):
# Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
batch_size = compute_effective_axis_dimension(
batch_size, fixed_dimension=OnnxConfig.DEFAULT_FIXED_BATCH, num_token_to_add=0
batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
)
# If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
seq_length = compute_effective_axis_dimension(
seq_length, fixed_dimension=OnnxConfig.DEFAULT_FIXED_SEQUENCE, num_token_to_add=token_to_add
seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
)
# Generate dummy inputs according to the computed batch and sequence
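The renamed `default_fixed_batch` / `default_fixed_sequence` attributes feed the axis helper used above. A rough sketch of that helper's logic (an approximation, not the library source), assuming, as the comments say, that -1 marks a dynamic axis and that special tokens are reserved out of the fixed target:
def compute_effective_axis_dimension(dimension: int, fixed_dimension: int, num_token_to_add: int = 0) -> int:
    # Dynamic axis (-1): fall back to the fixed default (2 samples for batch, 8 tokens for sequence)
    if dimension <= 0:
        dimension = fixed_dimension
    # Leave room for the special tokens the tokenizer will add
    return dimension - num_token_to_add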
@@ -22,6 +22,7 @@ import numpy as np
from packaging.version import Version, parse
from ..file_utils import TensorType, is_tf_available, is_torch_available, is_torch_onnx_dict_inputs_support_available
from ..tokenization_utils_base import PreTrainedTokenizerBase
from ..utils import logging
from .config import OnnxConfig
@@ -100,11 +101,17 @@ def export_pytorch(
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
the ONNX configuration.
"""
if isinstance(preprocessor, PreTrainedTokenizerBase) and tokenizer is not None:
raise ValueError("You cannot provide both a tokenizer and a preprocessor to export the model.")
if tokenizer is not None:
warnings.warn(
"The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use `preprocessor` instead.",
FutureWarning,
)
logger.info("Overwriting the `preprocessor` argument with `tokenizer` to generate dummmy inputs.")
preprocessor = tokenizer
if issubclass(type(model), PreTrainedModel):
import torch
from torch.onnx import export as onnx_export
@@ -123,9 +130,7 @@ def export_pytorch(
# Ensure inputs match
# TODO: Check when exporting QA we provide "is_pair=True"
model_inputs = config.generate_dummy_inputs(
preprocessor, tokenizer=tokenizer, framework=TensorType.PYTORCH
)
model_inputs = config.generate_dummy_inputs(preprocessor, framework=TensorType.PYTORCH)
inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
onnx_outputs = list(config.outputs.keys())
@@ -213,11 +218,15 @@ def export_tensorflow(
import onnx
import tf2onnx
if isinstance(preprocessor, PreTrainedTokenizerBase) and tokenizer is not None:
raise ValueError("You cannot provide both a tokenizer and preprocessor to export the model.")
if tokenizer is not None:
warnings.warn(
"The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use `preprocessor` instead.",
FutureWarning,
)
logger.info("Overwriting the `preprocessor` argument with `tokenizer` to generate dummmy inputs.")
preprocessor = tokenizer
model.config.return_dict = True
@@ -229,7 +238,7 @@
setattr(model.config, override_config_key, override_config_value)
# Ensure inputs match
model_inputs = config.generate_dummy_inputs(preprocessor, tokenizer=tokenizer, framework=TensorType.TENSORFLOW)
model_inputs = config.generate_dummy_inputs(preprocessor, framework=TensorType.TENSORFLOW)
inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
onnx_outputs = list(config.outputs.keys())
@@ -273,11 +282,16 @@ def export(
"Cannot convert because neither PyTorch nor TensorFlow are not installed. "
"Please install torch or tensorflow first."
)
if isinstance(preprocessor, PreTrainedTokenizerBase) and tokenizer is not None:
raise ValueError("You cannot provide both a tokenizer and a preprocessor to export the model.")
if tokenizer is not None:
warnings.warn(
"The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use `preprocessor` instead.",
FutureWarning,
)
logger.info("Overwriting the `preprocessor` argument with `tokenizer` to generate dummmy inputs.")
preprocessor = tokenizer
if is_torch_available():
from ..file_utils import torch_version
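For orientation, a rough sketch of the dispatch export() performs after the check above (an approximation, not the library code; the real function also validates the installed torch version via the `torch_version` import shown here and inspects the model's base class):
from transformers.file_utils import is_torch_available
from transformers.onnx.convert import export_pytorch, export_tensorflow

def dispatch_export(preprocessor, model, config, opset, output):
    # Prefer the PyTorch exporter when torch is installed, otherwise fall back to the tf2onnx path.
    if is_torch_available():
        return export_pytorch(preprocessor, model, config, opset, output)
    return export_tensorflow(preprocessor, model, config, opset, output)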
@@ -309,16 +323,22 @@ def validate_model_outputs(
logger.info("Validating ONNX model...")
if isinstance(preprocessor, PreTrainedTokenizerBase) and tokenizer is not None:
raise ValueError("You cannot provide both a tokenizer and a preprocessor to validatethe model outputs.")
if tokenizer is not None:
warnings.warn(
"The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use `preprocessor` instead.",
FutureWarning,
)
logger.info("Overwriting the `preprocessor` argument with `tokenizer` to generate dummmy inputs.")
preprocessor = tokenizer
# TODO: generate inputs with a different batch_size and seq_len than was used for conversion to properly test
# dynamic input shapes.
if issubclass(type(reference_model), PreTrainedModel):
reference_model_inputs = config.generate_dummy_inputs(
preprocessor, tokenizer=tokenizer, framework=TensorType.PYTORCH
)
if is_torch_available() and issubclass(type(reference_model), PreTrainedModel):
reference_model_inputs = config.generate_dummy_inputs(preprocessor, framework=TensorType.PYTORCH)
else:
reference_model_inputs = config.generate_dummy_inputs(
preprocessor, tokenizer=tokenizer, framework=TensorType.TENSORFLOW
)
reference_model_inputs = config.generate_dummy_inputs(preprocessor, framework=TensorType.TENSORFLOW)
# Create ONNX Runtime session
options = SessionOptions()
@@ -368,7 +388,7 @@ def validate_model_outputs(
# Check the shape and values match
for name, ort_value in zip(onnx_named_outputs, onnx_outputs):
if issubclass(type(reference_model), PreTrainedModel):
if is_torch_available() and issubclass(type(reference_model), PreTrainedModel):
ref_value = ref_outputs_dict[name].detach().numpy()
else:
ref_value = ref_outputs_dict[name].numpy()
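Once both the reference and the ONNX Runtime values are NumPy arrays, the per-output check boils down to a shape comparison plus an element-wise tolerance test. A rough sketch (hypothetical helper name, not the library code), assuming `atol` is the tolerance argument of validate_model_outputs:
import numpy as np

def outputs_close(ref_value: np.ndarray, ort_value: np.ndarray, atol: float) -> bool:
    # Shapes must match exactly; values may differ element-wise by at most `atol`.
    return ref_value.shape == ort_value.shape and np.allclose(ref_value, ort_value, atol=atol)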
@@ -402,7 +422,7 @@ def ensure_model_and_config_inputs_match(
:param model_inputs: :param config_inputs: :return:
"""
if issubclass(type(model), PreTrainedModel):
if is_torch_available() and issubclass(type(model), PreTrainedModel):
forward_parameters = signature(model.forward).parameters
else:
forward_parameters = signature(model.call).parameters
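ensure_model_and_config_inputs_match, patched in the last hunk above, relies on inspecting the model's forward/call signature. A simplified sketch of that idea (hypothetical helper name and return convention, not the library implementation):
from inspect import signature

from transformers import PreTrainedModel, is_torch_available

def match_model_inputs(model, dummy_input_names):
    # Only trust the torch branch when torch is actually installed; TF models expose call() instead.
    if is_torch_available() and issubclass(type(model), PreTrainedModel):
        forward_parameters = signature(model.forward).parameters
    else:
        forward_parameters = signature(model.call).parameters
    # Keep the dummy-input names the signature accepts, in signature order.
    matching = [name for name in forward_parameters if name in set(dummy_input_names)]
    return len(matching) == len(set(dummy_input_names)), matching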
@@ -196,28 +196,19 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
("m2m-100", "facebook/m2m100_418M"),
}
# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_MODELS` once TensorFlow has parity with the PyTorch model implementations.
TENSORFLOW_EXPORT_DEFAULT_MODELS = {
("albert", "hf-internal-testing/tiny-albert"),
("bert", "bert-base-cased"),
("ibert", "kssteven/ibert-roberta-base"),
("camembert", "camembert-base"),
("distilbert", "distilbert-base-cased"),
("roberta", "roberta-base"),
("xlm-roberta", "xlm-roberta-base"),
("layoutlm", "microsoft/layoutlm-base-uncased"),
}
TENSORFLOW_EXPORT_WITH_PAST_MODELS = {
("gpt2", "gpt2"),
("gpt-neo", "EleutherAI/gpt-neo-125M"),
}
# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_WITH_PAST_MODELS` once TensorFlow has parity with the PyTorch model implementations.
TENSORFLOW_EXPORT_WITH_PAST_MODELS = {}
TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
("bart", "facebook/bart-base"),
("mbart", "sshleifer/tiny-mbart"),
("t5", "t5-small"),
("marian", "Helsinki-NLP/opus-mt-en-de"),
}
# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS` once TensorFlow has parity with the PyTorch model implementations.
TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {}
def _get_models_to_test(export_models_list):
@@ -312,13 +303,13 @@ class OnnxExportTestCaseV2(TestCase):
def test_tensorflow_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_WITH_PAST_MODELS))
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_WITH_PAST_MODELS), skip_on_empty=True)
@slow
@require_tf
def test_tensorflow_export_with_past(self, test_name, name, model_name, feature, onnx_config_class_constructor):
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS))
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS), skip_on_empty=True)
@slow
@require_tf
def test_tensorflow_export_seq2seq_with_past(
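The skip_on_empty=True flag used above comes from the `parameterized` package: expanding an empty iterable normally raises at collection time, whereas with the flag the decorated test is simply reported as skipped. A small, self-contained illustration (hypothetical test and constant names):
from unittest import TestCase

from parameterized import parameterized

EMPTY_MODELS = []  # stands in for the now-empty TENSORFLOW_EXPORT_*_WITH_PAST_MODELS collections

class ExampleTests(TestCase):
    @parameterized.expand(EMPTY_MODELS, skip_on_empty=True)
    def test_export_with_past(self, model_type, model_name):
        # Never runs while the list is empty; the test is reported as skipped instead of erroring out.
        self.assertTrue(model_name)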