Unverified commit cb7ed6e0 authored by Alberto Bégué, committed by GitHub

Add TensorFlow handling of ONNX conversion (#13831)



* Add TensorFlow support for ONNX export

* Change documentation to mention conversion with TensorFlow

* Refactor export into export_pytorch and export_tensorflow

* Check model's type instead of framework installation to choose between TF and PyTorch

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Alberto Bégué <alberto.begue@della.ai>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
parent e923917c
@@ -62,10 +62,6 @@ Ready-made configurations include the following architectures:

- XLM-RoBERTa
- XLM-RoBERTa-XL

In the next two sections, we'll show you how to:

* Export a supported model using the `transformers.onnx` package.
@@ -150,6 +146,8 @@ DistilBERT we have:

["last_hidden_state"]
```

The approach is similar for TensorFlow models.
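
As a rough sketch of the same workflow on the TensorFlow side (assuming `tf2onnx` and `onnxruntime` are installed; the checkpoint name, output path, and the optional validation step below are illustrative rather than mandated by the exporter):

```python
from pathlib import Path

from transformers import AutoTokenizer, TFAutoModel
from transformers.models.distilbert import DistilBertOnnxConfig
from transformers.onnx import export, validate_model_outputs

ckpt = "distilbert-base-uncased"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = TFAutoModel.from_pretrained(ckpt)  # TensorFlow variant of the checkpoint

# Build the ONNX config from the model config; export() dispatches to the
# TensorFlow path (tf2onnx) because the model is a TFPreTrainedModel.
onnx_config = DistilBertOnnxConfig(model.config)
onnx_path = Path("model.onnx")
onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, onnx_config.default_onnx_opset, onnx_path)
print(onnx_outputs)  # e.g. ["last_hidden_state"]

# Optional sanity check of the exported graph against the reference model.
validate_model_outputs(onnx_config, tokenizer, model, onnx_path, onnx_outputs, onnx_config.atol_for_validation)
```

When both frameworks are installed, `export` chooses the backend from the class of the model you pass in, which is why the example loads the checkpoint with `TFAutoModel`.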
### Selecting features for different model topologies

Each ready-made configuration comes with a set of _features_ that enable you to
......
@@ -21,7 +21,7 @@ import numpy as np
from packaging.version import Version, parse

from transformers import PreTrainedModel, PreTrainedTokenizer, TensorType, TFPreTrainedModel, is_torch_available
from transformers.file_utils import is_tf_available, is_torch_onnx_dict_inputs_support_available
from transformers.onnx.config import OnnxConfig
from transformers.utils import logging
@@ -62,90 +62,190 @@ def check_onnxruntime_requirements(minimum_version: Version):
        )

def export_pytorch(
    tokenizer: PreTrainedTokenizer,
    model: PreTrainedModel,
    config: OnnxConfig,
    opset: int,
    output: Path,
) -> Tuple[List[str], List[str]]:
    """
    Export a PyTorch model to an ONNX Intermediate Representation (IR)

    Args:
        tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer used for encoding the data.
        model ([`PreTrainedModel`]):
            The model to export.
        config ([`~onnx.config.OnnxConfig`]):
            The ONNX configuration associated with the exported model.
        opset (`int`):
            The version of the ONNX operator set to use.
        output (`Path`):
            Directory to store the exported ONNX model.

    Returns:
        `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
        the ONNX configuration.
    """
    if issubclass(type(model), PreTrainedModel):
        import torch
        from torch.onnx import export as onnx_export

        logger.info(f"Using framework PyTorch: {torch.__version__}")
        with torch.no_grad():
            model.config.return_dict = True
            model.eval()

            # Check if we need to override certain configuration item
            if config.values_override is not None:
                logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
                for override_config_key, override_config_value in config.values_override.items():
                    logger.info(f"\t- {override_config_key} -> {override_config_value}")
                    setattr(model.config, override_config_key, override_config_value)

            # Ensure inputs match
            # TODO: Check when exporting QA we provide "is_pair=True"
            model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
            inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
            onnx_outputs = list(config.outputs.keys())

            if not inputs_match:
                raise ValueError("Model and config inputs doesn't match")

            config.patch_ops()

            # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
            # so we check the torch version for backwards compatibility
            if parse(torch.__version__) <= parse("1.10.99"):
                # export can work with named args but the dict containing named args
                # has to be the last element of the args tuple.
                onnx_export(
                    model,
                    (model_inputs,),
                    f=output.as_posix(),
                    input_names=list(config.inputs.keys()),
                    output_names=onnx_outputs,
                    dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
                    do_constant_folding=True,
                    use_external_data_format=config.use_external_data_format(model.num_parameters()),
                    enable_onnx_checker=True,
                    opset_version=opset,
                )
            else:
                onnx_export(
                    model,
                    (model_inputs,),
                    f=output.as_posix(),
                    input_names=list(config.inputs.keys()),
                    output_names=onnx_outputs,
                    dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
                    do_constant_folding=True,
                    opset_version=opset,
                )

            config.restore_ops()

    return matched_inputs, onnx_outputs

def export_tensorflow(
    tokenizer: PreTrainedTokenizer,
    model: TFPreTrainedModel,
    config: OnnxConfig,
    opset: int,
    output: Path,
) -> Tuple[List[str], List[str]]:
    """
    Export a TensorFlow model to an ONNX Intermediate Representation (IR)

    Args:
        tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer used for encoding the data.
        model ([`TFPreTrainedModel`]):
            The model to export.
        config ([`~onnx.config.OnnxConfig`]):
            The ONNX configuration associated with the exported model.
        opset (`int`):
            The version of the ONNX operator set to use.
        output (`Path`):
            Directory to store the exported ONNX model.

    Returns:
        `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
        the ONNX configuration.
    """
    import tensorflow as tf

    import onnx
    import tf2onnx

    model.config.return_dict = True

    # Check if we need to override certain configuration item
    if config.values_override is not None:
        logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
        for override_config_key, override_config_value in config.values_override.items():
            logger.info(f"\t- {override_config_key} -> {override_config_value}")
            setattr(model.config, override_config_key, override_config_value)

    # Ensure inputs match
    model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.TENSORFLOW)
    inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
    onnx_outputs = list(config.outputs.keys())

    input_signature = [tf.TensorSpec.from_tensor(tensor, name=key) for key, tensor in model_inputs.items()]
    onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature, opset=opset)
    onnx.save(onnx_model, output.as_posix())
    config.restore_ops()

    return matched_inputs, onnx_outputs
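
# ---------------------------------------------------------------------------
# Editorial usage sketch (not part of this diff): export_tensorflow() can also
# be called directly, assuming tf2onnx is installed; the checkpoint name and
# output path below are illustrative only.
#
#   from pathlib import Path
#   from transformers import AutoTokenizer, TFAutoModel
#   from transformers.models.distilbert import DistilBertOnnxConfig
#   from transformers.onnx.convert import export_tensorflow
#
#   tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
#   model = TFAutoModel.from_pretrained("distilbert-base-uncased")
#   onnx_config = DistilBertOnnxConfig(model.config)
#   inputs, outputs = export_tensorflow(
#       tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path("model.onnx")
#   )
# ---------------------------------------------------------------------------
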
def export(
    tokenizer: PreTrainedTokenizer,
    model: Union[PreTrainedModel, TFPreTrainedModel],
    config: OnnxConfig,
    opset: int,
    output: Path,
) -> Tuple[List[str], List[str]]:
    """
    Export a PyTorch or TensorFlow model to an ONNX Intermediate Representation (IR)

    Args:
        tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer used for encoding the data.
        model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
            The model to export.
        config ([`~onnx.config.OnnxConfig`]):
            The ONNX configuration associated with the exported model.
        opset (`int`):
            The version of the ONNX operator set to use.
        output (`Path`):
            Directory to store the exported ONNX model.

    Returns:
        `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
        the ONNX configuration.
    """
    if not (is_torch_available() or is_tf_available()):
        raise ImportError(
            "Cannot convert because neither PyTorch nor TensorFlow is installed. "
            "Please install torch or tensorflow first."
        )

    if is_torch_available():
        from transformers.file_utils import torch_version

        if not is_torch_onnx_dict_inputs_support_available():
            raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")

    if is_torch_available() and issubclass(type(model), PreTrainedModel):
        return export_pytorch(tokenizer, model, config, opset, output)
    elif is_tf_available() and issubclass(type(model), TFPreTrainedModel):
        return export_tensorflow(tokenizer, model, config, opset, output)

def validate_model_outputs(
    config: OnnxConfig,
    tokenizer: PreTrainedTokenizer,
@@ -160,7 +260,10 @@ def validate_model_outputs(
    # TODO: generate inputs with a different batch_size and seq_len that was used for conversion to properly test
    # dynamic input shapes.
    if issubclass(type(reference_model), PreTrainedModel):
        reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
    else:
        reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.TENSORFLOW)

    # Create ONNX Runtime session
    options = SessionOptions()
@@ -210,7 +313,10 @@ def validate_model_outputs(
    # Check the shape and values match
    for name, ort_value in zip(onnx_named_outputs, onnx_outputs):
        if issubclass(type(reference_model), PreTrainedModel):
            ref_value = ref_outputs_dict[name].detach().numpy()
        else:
            ref_value = ref_outputs_dict[name].numpy()
        logger.info(f'\t- Validating ONNX Model output "{name}":')

        # Shape
@@ -241,7 +347,10 @@ def ensure_model_and_config_inputs_match(
    :param model_inputs: :param config_inputs: :return:
    """
    if issubclass(type(model), PreTrainedModel):
        forward_parameters = signature(model.forward).parameters
    else:
        forward_parameters = signature(model.call).parameters
    model_inputs_set = set(model_inputs)
    # We are fine if config_inputs has more keys than model_inputs
......
from functools import partial, reduce
from typing import Callable, Dict, Optional, Tuple, Type, Union

from .. import PretrainedConfig, PreTrainedModel, TFPreTrainedModel, is_tf_available, is_torch_available
from ..models.albert import AlbertOnnxConfig
from ..models.bart import BartOnnxConfig
from ..models.bert import BertOnnxConfig
@@ -24,7 +24,6 @@ from .config import OnnxConfig

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

if is_torch_available():
    from transformers.models.auto import (
        AutoModel,
        AutoModelForCausalLM,
@@ -35,9 +34,20 @@ if is_torch_available():
        AutoModelForSequenceClassification,
        AutoModelForTokenClassification,
    )
elif is_tf_available():
    from transformers.models.auto import (
        TFAutoModel,
        TFAutoModelForCausalLM,
        TFAutoModelForMaskedLM,
        TFAutoModelForMultipleChoice,
        TFAutoModelForQuestionAnswering,
        TFAutoModelForSeq2SeqLM,
        TFAutoModelForSequenceClassification,
        TFAutoModelForTokenClassification,
    )
else:
    logger.warning(
        "The ONNX export features are only supported for PyTorch or TensorFlow. You will not be able to export models without one of these libraries installed."
    )
@@ -80,6 +90,17 @@ class FeaturesManager:
            "multiple-choice": AutoModelForMultipleChoice,
            "question-answering": AutoModelForQuestionAnswering,
        }
    elif is_tf_available():
        _TASKS_TO_AUTOMODELS = {
            "default": TFAutoModel,
            "masked-lm": TFAutoModelForMaskedLM,
            "causal-lm": TFAutoModelForCausalLM,
            "seq2seq-lm": TFAutoModelForSeq2SeqLM,
            "sequence-classification": TFAutoModelForSequenceClassification,
            "token-classification": TFAutoModelForTokenClassification,
            "multiple-choice": TFAutoModelForMultipleChoice,
            "question-answering": TFAutoModelForQuestionAnswering,
        }
    else:
        _TASKS_TO_AUTOMODELS = {}
@@ -270,7 +291,7 @@ class FeaturesManager:
            )
        return FeaturesManager._TASKS_TO_AUTOMODELS[task]

    def get_model_from_feature(feature: str, model: str) -> Union[PreTrainedModel, TFPreTrainedModel]:
        """
        Attempt to retrieve a model from a model's name and the feature to be enabled.
@@ -286,7 +307,9 @@ class FeaturesManager:
        return model_class.from_pretrained(model)

    @staticmethod
    def check_supported_model_or_raise(
        model: Union[PreTrainedModel, TFPreTrainedModel], feature: str = "default"
    ) -> Tuple[str, Callable]:
        """
        Check whether or not the model has the requested features.
......
@@ -4,7 +4,7 @@ from unittest import TestCase
from unittest.mock import patch

from parameterized import parameterized

from transformers import AutoConfig, AutoTokenizer, is_tf_available, is_torch_available
from transformers.onnx import (
    EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
    OnnxConfig,
@@ -15,11 +15,11 @@ from transformers.onnx import (
from transformers.onnx.config import OnnxConfigWithPast

if is_torch_available() or is_tf_available():
    from transformers.onnx.features import FeaturesManager

from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
from transformers.testing_utils import require_onnx, require_tf, require_torch, slow


@require_onnx
@@ -192,19 +192,44 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
    ("marian", "Helsinki-NLP/opus-mt-en-de"),
}

TENSORFLOW_EXPORT_DEFAULT_MODELS = {
    ("albert", "hf-internal-testing/tiny-albert"),
    ("bert", "bert-base-cased"),
    ("ibert", "kssteven/ibert-roberta-base"),
    ("camembert", "camembert-base"),
    ("distilbert", "distilbert-base-cased"),
    ("roberta", "roberta-base"),
    ("xlm-roberta", "xlm-roberta-base"),
    ("layoutlm", "microsoft/layoutlm-base-uncased"),
}

TENSORFLOW_EXPORT_WITH_PAST_MODELS = {
    ("gpt2", "gpt2"),
    ("gpt-neo", "EleutherAI/gpt-neo-125M"),
}

TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
    ("bart", "facebook/bart-base"),
    ("mbart", "sshleifer/tiny-mbart"),
    ("t5", "t5-small"),
    ("marian", "Helsinki-NLP/opus-mt-en-de"),
}

def _get_models_to_test(export_models_list):
    models_to_test = []
    if is_torch_available() or is_tf_available():
        for (name, model) in export_models_list:
            for feature, onnx_config_class_constructor in FeaturesManager.get_supported_features_for_model_type(
                name
            ).items():
                models_to_test.append((f"{name}_{feature}", name, model, feature, onnx_config_class_constructor))
        return sorted(models_to_test)
    else:
        # Returning some dummy test that should not be ever called because of the @require_torch / @require_tf
        # decorators.
        # The reason for not returning an empty list is because parameterized.expand complains when it's empty.
        return [("dummy", "dummy", "dummy", "dummy", OnnxConfig.from_model_config)]

class OnnxExportTestCaseV2(TestCase):
@@ -212,7 +237,7 @@ class OnnxExportTestCaseV2(TestCase):
    Integration tests ensuring supported models are correctly exported
    """

    def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
        from transformers.onnx import export

        tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -246,13 +271,13 @@ class OnnxExportTestCaseV2(TestCase):
    @slow
    @require_torch
    def test_pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
        self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)

    @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_WITH_PAST_MODELS))
    @slow
    @require_torch
    def test_pytorch_export_with_past(self, test_name, name, model_name, feature, onnx_config_class_constructor):
        self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)

    @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS))
    @slow
@@ -260,4 +285,24 @@ class OnnxExportTestCaseV2(TestCase):
    def test_pytorch_export_seq2seq_with_past(
        self, test_name, name, model_name, feature, onnx_config_class_constructor
    ):
        self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)

    @parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_DEFAULT_MODELS))
    @slow
    @require_tf
    def test_tensorflow_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
        self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)

    @parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_WITH_PAST_MODELS))
    @slow
    @require_tf
    def test_tensorflow_export_with_past(self, test_name, name, model_name, feature, onnx_config_class_constructor):
        self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)

    @parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS))
    @slow
    @require_tf
    def test_tensorflow_export_seq2seq_with_past(
        self, test_name, name, model_name, feature, onnx_config_class_constructor
    ):
        self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
@@ -211,7 +211,7 @@ def check_onnx_model_list(overwrite=False):
    current_list, start_index, end_index, lines = _find_text_in_file(
        filename=os.path.join(PATH_TO_DOCS, "serialization.mdx"),
        start_prompt="<!--This table is automatically generated by `make fix-copies`, do not fill manually!-->",
        end_prompt="In the next two sections, we'll show you how to:",
    )
    new_list = get_onnx_model_list()
......