Unverified Commit 2aa3cd93 authored by Funtowicz Morgan, committed by GitHub

[RFC] Laying down building stone for more flexible ONNX export capabilities (#11786)



* Laying down building stone for more flexible ONNX export capabilities

* Ability to provide a map of config key to override before exporting.

* Makes it possible to export BART with/without past keys.

* Supports simple mathematical syntax for OnnxVariable.repeated

* Effectively apply value override from onnx config for model

* Supports export with additional features such as with-past for seq2seq

* Store the output path directly in the args for uniform usage across.

* Make BART_ONNX_CONFIG_* constants and fix imports.

* Support BERT model.

* Use tokenizer for more flexibility in defining the inputs of a model.

* Add TODO as a reminder to provide the batch/sequence_length as CLI args

* Enable optimizations to be done on the model.

* Enable GPT2 + past

* Improve model validation with outputs containing nested structures

* Enable Roberta

* Enable Albert

* Albert requires opset >= 12

* BERT-like models require opset >= 12

* Remove double printing.

* Enable XLM-Roberta

* Enable DistilBERT

* Disable optimization by default

* Fix missing setattr when applying optimizer_features

* Add value field to OnnxVariable to define constant input (not from tokenizers)

* Add T5 support.

* Simplify model type retrieval

* Example exporting token_classification pipeline for DistilBERT.

* Refactoring to package `transformers.onnx`

* Solve circular dependency & __main__

* Remove unnecessary imports in `__init__`

* Licences

* Use @Narsil's suggestion to forward the model's configuration to the ONNXConfig to avoid interpolation.

* Onnx export v2 fixes (#12388)

* Tiny fixes
Remove `convert_pytorch` from onnxruntime-less runtimes
Correct reference to model

* Style

* Fix Copied from

* LongFormer ONNX config.

* Removed optimizations

* Remove bad merge replicas.

* Remove unused constants.

* Remove some deleted constants from imports.

* Fix unittest to remove usage of PyTorch model for onnx.utils.

* Fix distilbert export

* Enable ONNX export test for supported model.

* Style.

* Fix lint.

* Enable all supported default models.

* GPT2 only has one output

* Fix bad property name when overriding config.

* Added unittests and docstrings.

* Disable with_past tests for now.

* Enable outputs validation for default export.

* Remove graph opt lvls.

* Last commit with on-going past commented.

* Style.

* Disabled `with_past` for now

* Remove unused imports.

* Remove framework argument

* Remove TFPreTrainedModel reference

* Add documentation

* Add onnxruntime tests to CircleCI

* Add test

* Rename `convert_pytorch` to `export`

* Use OrderedDict for dummy inputs

* WIP Wav2Vec2

* Revert "WIP Wav2Vec2"

This reverts commit f665efb04c92525c3530e589029f0ae7afdf603e.

* Style

* Use OrderedDict for I/O

* Style.

* Specify OrderedDict documentation.

* Style :)
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent 0085e712
......@@ -345,6 +345,32 @@ jobs:
- '~/.cache/pip'
- run: python -m pytest -sv ./tests/ -m is_staging_test
run_tests_onnxruntime:
    working_directory: ~/transformers
    docker:
        - image: circleci/python:3.7
    environment:
        OMP_NUM_THREADS: 1
        TRANSFORMERS_IS_CI: yes
    resource_class: xlarge
    parallelism: 1
    steps:
        - checkout
        - restore_cache:
            keys:
                - v0.4-torch-{{ checksum "setup.py" }}
                - v0.4-{{ checksum "setup.py" }}
        - run: pip install --upgrade pip
        - run: pip install .[torch,testing,sentencepiece,onnxruntime]
        - save_cache:
            key: v0.4-onnx-{{ checksum "setup.py" }}
            paths:
                - '~/.cache/pip'
        - run: python -m pytest -n 1 --dist=loadfile -s --make-reports=tests_torch ./tests/* -k onnx | tee tests_output.txt
        - store_artifacts:
            path: ~/transformers/tests_output.txt
        - store_artifacts:
            path: ~/transformers/reports

build_doc:
    working_directory: ~/transformers
    docker:
......@@ -485,6 +511,7 @@ workflows:
- run_tests_flax
- run_tests_pipelines_torch
- run_tests_pipelines_tf
- run_tests_onnxruntime
- run_tests_hub
- build_doc
- deploy_doc: *workflow_filters
......
......@@ -21,11 +21,137 @@ Projects `ONNX (Open Neural Network eXchange) <http://onnx.ai>`_ and `ONNXRuntim
unified and community-driven format to store and, by extension, efficiently execute neural networks, leveraging a variety
of hardware and dedicated optimizations.

Starting from transformers v2.10.0 we partnered with ONNX Runtime to provide an easy export of transformers models to
the ONNX format. You can read more about this effort in our joint blog post `Accelerate your NLP pipelines
using Hugging Face Transformers and ONNX Runtime
<https://medium.com/microsoftazure/accelerate-your-nlp-pipelines-using-hugging-face-transformers-and-onnx-runtime-2443578f4333>`_.
Configuration-based approach
-----------------------------------------------------------------------------------------------------------------------
Transformers v4.9.0 introduces a new package: ``transformers.onnx``. This package makes it possible to convert model
checkpoints to an ONNX graph by leveraging configuration objects. These configuration objects come ready-made for a
number of model architectures and are designed to be easily extendable to other architectures.
Ready-made configurations include the following models:
- ALBERT
- BART
- BERT
- DistilBERT
- GPT-2
- RoBERTa
- T5
- XLM-RoBERTa
This conversion uses the PyTorch version of the models and therefore requires PyTorch to be installed. If you would
like to be able to convert from TensorFlow, please let us know by opening an issue.
.. note::
    The models showcased here are close to feature complete, but some capabilities are still in development. In
    particular, handling the past key values for decoder models is currently in the works.
Converting a model to ONNX using the ``transformers.onnx`` package
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The package may be used as a Python module:
.. code-block::

    python -m transformers.onnx --help
    usage: Hugging Face ONNX Exporter tool [-h] -m MODEL -f {pytorch} [--features {default}] [--opset OPSET] [--atol ATOL] output

    positional arguments:
      output                Path indicating where to store generated ONNX model.

    optional arguments:
      -h, --help            show this help message and exit
      -m MODEL, --model MODEL
                            Model's name or path on disk to load.
      -f {pytorch}, --framework {pytorch}
                            Framework to use when exporting. Possible values are: {'pytorch'}
      --features {default}  Export the model with some additional features.
      --opset OPSET         ONNX opset version to export the model with (default 12).
      --atol ATOL           Absolute difference tolerance when validating the model.
Exporting a checkpoint using a ready-made configuration can be done as follows:
.. code-block::

    python -m transformers.onnx -f pytorch --model=bert-base-cased onnx/bert-base-cased/
This exports an ONNX graph of the mentioned checkpoint, here ``bert-base-cased``, but it can be any model from the hub
or a local path. The graph is saved under ``onnx/bert-base-cased``, and you should see logs similar to the following:
.. code-block::

    Validating ONNX model...
        -[✓] ONNX model output names match reference model ({'pooler_output', 'last_hidden_state'})
        - Validating ONNX Model output "last_hidden_state":
            -[✓] (2, 8, 768) matches (2, 8, 768)
            -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "pooler_output":
            -[✓] (2, 768) matches (2, 768)
            -[✓] all values close (atol: 0.0001)
    All good, model saved at: onnx/bert-base-cased/model.onnx
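
Once the export has succeeded, the resulting graph can be loaded and run with ONNX Runtime. The snippet below is a
minimal sketch, assuming ``onnxruntime`` is installed and that the file was produced by the ``bert-base-cased`` export
above; the input and output names come from the model's ONNX configuration.

.. code-block::

    from onnxruntime import InferenceSession
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    session = InferenceSession("onnx/bert-base-cased/model.onnx")

    # ONNX Runtime expects NumPy arrays, keyed by the input names declared in the ONNX config.
    inputs = tokenizer("Using BERT with ONNX Runtime!", return_tensors="np")
    outputs = session.run(output_names=["last_hidden_state"], input_feed=dict(inputs))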
Implementing a custom configuration for an unsupported architecture
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Let's take a look at the changes necessary to add a custom configuration for an unsupported architecture. First, we
need a custom ONNX configuration object that details the model inputs and outputs. The BERT ONNX configuration is
shown below:
.. code-block::

    class BertOnnxConfig(OnnxConfig):
        @property
        def inputs(self) -> Mapping[str, Mapping[int, str]]:
            return OrderedDict(
                [
                    ("input_ids", {0: "batch", 1: "sequence"}),
                    ("attention_mask", {0: "batch", 1: "sequence"}),
                    ("token_type_ids", {0: "batch", 1: "sequence"}),
                ]
            )

        @property
        def outputs(self) -> Mapping[str, Mapping[int, str]]:
            return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])
Let's understand what is happening here. This configuration has two properties: the inputs and the outputs.

The ``inputs`` property returns a mapping, where each key corresponds to an expected input and each value maps that
input's dynamic axes (axis index to symbolic name).

For BERT, three inputs are required. They share the same shape, which is made up of two dimensions: the first
dimension is the batch and the second is the sequence.

The ``outputs`` property returns a similar mapping, where, once again, each key corresponds to an expected output and
each value indicates its dynamic axes.
Once this is done, a single step remains: adding this configuration object to the initialisation of the model class,
and to the general ``transformers`` initialisation.
An important fact to notice is the use of ``OrderedDict`` in both the inputs and outputs properties. This is a
requirement, as inputs are matched against their relative position within the ``PreTrainedModel.forward()`` prototype,
and outputs are matched against their position in the returned ``BaseModelOutputX`` instance.
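
For reference, a minimal sketch of how such a configuration could be driven programmatically is shown below. The entry
points (``export`` and ``validate_model_outputs`` from ``transformers.onnx``) and their exact signatures are assumed
from this version of the package and may differ; the command line interface shown earlier remains the recommended path.

.. code-block::

    from pathlib import Path

    from transformers import AutoModel, AutoTokenizer
    from transformers.models.bert import BertOnnxConfig
    from transformers.onnx import export, validate_model_outputs

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    model = AutoModel.from_pretrained("bert-base-cased")

    # The ONNX configuration is built from the model's regular configuration.
    onnx_config = BertOnnxConfig(model.config)

    # Export the graph, then check the ONNX outputs against the PyTorch reference outputs.
    onnx_path = Path("onnx/bert-base-cased/model.onnx")
    onnx_path.parent.mkdir(parents=True, exist_ok=True)
    onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, opset=12, output=onnx_path)
    validate_model_outputs(onnx_config, tokenizer, model, onnx_path, onnx_outputs, atol=1e-4)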
Graph conversion
-----------------------------------------------------------------------------------------------------------------------
.. note::
    The approach detailed here is being deprecated. We recommend you follow the configuration-based approach above for
    an up-to-date workflow.

Exporting a model is done through the script ``convert_graph_to_onnx.py`` at the root of the transformers sources. The
following command shows how to export a BERT model from the library; simply run:
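
A representative invocation would look like the following; the exact flags may vary between versions, so check
``python convert_graph_to_onnx.py --help`` for the current interface.

.. code-block::

    python convert_graph_to_onnx.py --framework pt --model bert-base-cased bert-base-cased.onnx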
......
......@@ -148,9 +148,30 @@ except importlib_metadata.PackageNotFoundError:
    _faiss_available = False

_onnx_available = (
    importlib.util.find_spec("keras2onnx") is not None and importlib.util.find_spec("onnxruntime") is not None
)

coloredlogs = importlib.util.find_spec("coloredlogs") is not None
try:
    _coloredlogs_available = importlib_metadata.version("coloredlogs")
    logger.debug(f"Successfully imported coloredlogs version {_coloredlogs_available}")
except importlib_metadata.PackageNotFoundError:
    _coloredlogs_available = False

sympy_available = importlib.util.find_spec("sympy") is not None
try:
    _sympy_available = importlib_metadata.version("sympy")
    logger.debug(f"Successfully imported sympy version {_sympy_available}")
except importlib_metadata.PackageNotFoundError:
    _sympy_available = False

_keras2onnx_available = importlib.util.find_spec("keras2onnx") is not None
try:
    _keras2onnx_version = importlib_metadata.version("keras2onnx")
    logger.debug(f"Successfully imported keras2onnx version {_keras2onnx_version}")
except importlib_metadata.PackageNotFoundError:
    _keras2onnx_available = False

_onnx_available = importlib.util.find_spec("onnxruntime") is not None
try:
    _onxx_version = importlib_metadata.version("onnx")
    logger.debug(f"Successfully imported onnx version {_onxx_version}")
......@@ -292,6 +313,14 @@ def is_tf_available():
    return _tf_available


def is_coloredlogs_available():
    return _coloredlogs_available


def is_keras2onnx_available():
    return _keras2onnx_available


def is_onnx_available():
    return _onnx_available
......
......@@ -28,7 +28,7 @@ from ...file_utils import (
_import_structure = {
"configuration_albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"],
"configuration_albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig", "AlbertOnnxConfig"],
}
if is_sentencepiece_available():
......@@ -67,7 +67,7 @@ if is_tf_available():
if TYPE_CHECKING:
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig, AlbertOnnxConfig
if is_sentencepiece_available():
from .tokenization_albert import AlbertTokenizer
......
......@@ -14,8 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" ALBERT model configuration """
from collections import OrderedDict
from typing import Mapping
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
......@@ -151,3 +154,20 @@ class AlbertConfig(PretrainedConfig):
        self.layer_norm_eps = layer_norm_eps
        self.classifier_dropout_prob = classifier_dropout_prob
        self.position_embedding_type = position_embedding_type


# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig with Roberta->Albert
class AlbertOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
                ("token_type_ids", {0: "batch", 1: "sequence"}),
            ]
        )

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])
......@@ -21,7 +21,7 @@ from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_to
_import_structure = {
"configuration_bart": ["BART_PRETRAINED_CONFIG_ARCHIVE_MAP", "BartConfig"],
"configuration_bart": ["BART_PRETRAINED_CONFIG_ARCHIVE_MAP", "BartConfig", "BartOnnxConfig"],
"tokenization_bart": ["BartTokenizer"],
}
......@@ -53,7 +53,7 @@ if is_flax_available():
]
if TYPE_CHECKING:
from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig
from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig, BartOnnxConfig
from .tokenization_bart import BartTokenizer
if is_tokenizers_available():
......
......@@ -14,8 +14,11 @@
# limitations under the License.
""" BART model configuration """
import warnings
from collections import OrderedDict
from typing import Mapping
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfigWithPast
from ...utils import logging
......@@ -186,3 +189,32 @@ class BartConfig(PretrainedConfig):
    @property
    def hidden_size(self) -> int:
        return self.d_model


class BartOnnxConfig(OnnxConfigWithPast):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
            ]
        )

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        if self.use_past:
            return OrderedDict(
                [
                    ("last_hidden_state", {0: "batch", 1: "sequence"}),
                    ("past_keys", {0: "batch", 2: "sequence"}),
                    ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
                ]
            )
        else:
            return OrderedDict(
                [
                    ("last_hidden_state", {0: "batch", 1: "sequence"}),
                    ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
                ]
            )
......@@ -22,7 +22,7 @@ from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_to
_import_structure = {
"configuration_bert": ["BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BertConfig"],
"configuration_bert": ["BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BertConfig", "BertOnnxConfig"],
"tokenization_bert": ["BasicTokenizer", "BertTokenizer", "WordpieceTokenizer"],
}
......@@ -77,7 +77,7 @@ if is_flax_available():
]
if TYPE_CHECKING:
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig, BertOnnxConfig
from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer
if is_tokenizers_available():
......
......@@ -14,8 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" BERT model configuration """
from collections import OrderedDict
from typing import Mapping
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
......@@ -154,3 +157,19 @@ class BertConfig(PretrainedConfig):
        self.gradient_checkpointing = gradient_checkpointing
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache


class BertOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
                ("token_type_ids", {0: "batch", 1: "sequence"}),
            ]
        )

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])
......@@ -22,7 +22,11 @@ from ...file_utils import _LazyModule, is_tf_available, is_tokenizers_available,
_import_structure = {
    "configuration_distilbert": ["DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DistilBertConfig"],
    "configuration_distilbert": [
        "DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "DistilBertConfig",
        "DistilBertOnnxConfig",
    ],
    "tokenization_distilbert": ["DistilBertTokenizer"],
}
......@@ -56,7 +60,11 @@ if is_tf_available():
if TYPE_CHECKING:
    from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
    from .configuration_distilbert import (
        DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        DistilBertConfig,
        DistilBertOnnxConfig,
    )
    from .tokenization_distilbert import DistilBertTokenizer

    if is_tokenizers_available():
......
......@@ -13,8 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" DistilBERT model configuration """
from collections import OrderedDict
from typing import Mapping
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
......@@ -135,3 +138,18 @@ class DistilBertConfig(PretrainedConfig):
    @property
    def num_hidden_layers(self):
        return self.n_layers


class DistilBertOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
            ]
        )

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"})])
......@@ -22,7 +22,7 @@ from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_to
_import_structure = {
"configuration_gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config"],
"configuration_gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2OnnxConfig"],
"tokenization_gpt2": ["GPT2Tokenizer"],
}
......@@ -55,7 +55,7 @@ if is_flax_available():
_import_structure["modeling_flax_gpt2"] = ["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"]
if TYPE_CHECKING:
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig
from .tokenization_gpt2 import GPT2Tokenizer
if is_tokenizers_available():
......
......@@ -14,8 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" OpenAI GPT-2 configuration """
from collections import OrderedDict
from typing import Any, Mapping, Optional
from transformers import PreTrainedTokenizer, TensorType, is_torch_available
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfigWithPast
from ...utils import logging
......@@ -195,3 +200,61 @@ class GPT2Config(PretrainedConfig):
    @property
    def num_hidden_layers(self):
        return self.n_layer


class GPT2OnnxConfig(OnnxConfigWithPast):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        common_inputs = OrderedDict({"input_ids": {0: "batch"}})
        if self.use_past:
            for i in range(self._config.n_layer * 2):
                common_inputs[f"past_key_values.{i}"] = {0: "batch", 2: "sequence"}

            common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
        else:
            common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}

        return common_inputs

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        common_outputs = OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}})
        if self.use_past:
            for i in range(self._config.n_layer * 2):
                common_outputs[f"present.{i}"] = {0: "batch", 2: "sequence"}

            return common_outputs

        return common_outputs

    def generate_dummy_inputs(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)

        # We need to order the inputs in the way they appear in the forward()
        ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})

        # Need to add the past_keys
        if self.use_past:
            if not is_torch_available():
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
            else:
                import torch

                batch = common_inputs["input_ids"].shape[0]
                ordered_inputs["past_key_values"] = [
                    (
                        torch.zeros((batch, self._config.n_head, 1, self._config.hidden_size // self._config.n_head)),
                        torch.zeros((batch, self._config.n_head, 1, self._config.hidden_size // self._config.n_head)),
                    )
                    for _ in range(self._config.n_layer)
                ]

        ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
        return ordered_inputs
......@@ -22,7 +22,11 @@ from ...file_utils import _LazyModule, is_tf_available, is_tokenizers_available,
_import_structure = {
    "configuration_longformer": ["LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongformerConfig"],
    "configuration_longformer": [
        "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "LongformerConfig",
        "LongformerOnnxConfig",
    ],
    "tokenization_longformer": ["LongformerTokenizer"],
}
......@@ -57,7 +61,11 @@ if is_tf_available():
if TYPE_CHECKING:
    from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
    from .configuration_longformer import (
        LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
        LongformerConfig,
        LongformerOnnxConfig,
    )
    from .tokenization_longformer import LongformerTokenizer

    if is_tokenizers_available():
......
......@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Longformer configuration """
from collections import OrderedDict
from typing import List, Mapping, Union
from typing import List, Union
from ...onnx import OnnxConfig
from ...utils import logging
from ..roberta.configuration_roberta import RobertaConfig
......@@ -69,3 +70,18 @@ class LongformerConfig(RobertaConfig):
    def __init__(self, attention_window: Union[List[int], int] = 512, sep_token_id: int = 2, **kwargs):
        super().__init__(sep_token_id=sep_token_id, **kwargs)
        self.attention_window = attention_window


class LongformerOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
            ]
        )

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])
......@@ -22,7 +22,7 @@ from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_to
_import_structure = {
"configuration_roberta": ["ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "RobertaConfig"],
"configuration_roberta": ["ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "RobertaConfig", "RobertaOnnxConfig"],
"tokenization_roberta": ["RobertaTokenizer"],
}
......@@ -68,7 +68,7 @@ if is_flax_available():
if TYPE_CHECKING:
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, RobertaOnnxConfig
from .tokenization_roberta import RobertaTokenizer
if is_tokenizers_available():
......
......@@ -14,7 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" RoBERTa configuration """
from collections import OrderedDict
from typing import Mapping
from ...onnx import OnnxConfig
from ...utils import logging
from ..bert.configuration_bert import BertConfig
......@@ -62,3 +65,18 @@ class RobertaConfig(BertConfig):
    def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, **kwargs):
        """Constructs RobertaConfig."""
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)


class RobertaOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
            ]
        )

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])
......@@ -29,7 +29,7 @@ from ...file_utils import (
_import_structure = {
"configuration_t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config"],
"configuration_t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config", "T5OnnxConfig"],
}
if is_sentencepiece_available():
......@@ -66,7 +66,7 @@ if is_flax_available():
if TYPE_CHECKING:
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config, T5OnnxConfig
if is_sentencepiece_available():
from .tokenization_t5 import T5Tokenizer
......
......@@ -13,8 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" T5 model configuration """
from collections import OrderedDict
from typing import Any, Mapping, Optional
from transformers import PreTrainedTokenizer, TensorType
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfigWithPast
from ...utils import logging
......@@ -132,3 +137,66 @@ class T5Config(PretrainedConfig):
    @property
    def num_hidden_layers(self):
        return self.num_layers


class T5OnnxConfig(OnnxConfigWithPast):
    def __init__(self, config: PretrainedConfig, use_past: bool = False):
        super().__init__(config, use_past)

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        common_inputs = OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                ("decoder_input_ids", {0: "batch"}),
                ("decoder_attention_mask", {0: "batch"}),
            ]
        )

        if self.use_past:
            for i in range(self._config.num_layers):
                common_inputs[f"past_key_values.{i}.decoder.0"] = ({0: "batch", 2: "past_sequence"},)
                common_inputs[f"past_key_values.{i}.decoder.1"] = ({0: "batch", 2: "past_sequence"},)
                common_inputs[f"past_key_values.{i}.encoder.0"] = ({0: "batch", 2: "past_sequence"},)
                common_inputs[f"past_key_values.{i}.encoder.1"] = ({0: "batch", 2: "past_sequence"},)

        return common_inputs

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        common_outputs = OrderedDict(
            [
                ("last_hidden_state", {0: "batch", 1: "decoder_sequence"}),
                ("encoder_last_hidden_state", {0: "batch", 2: "encoder_sequence"}),
            ]
        )

        if self.use_past:
            for i in range(self._config.num_layers):
                common_outputs[f"past_key_values.{i}.decoder.0"] = ({0: "batch", 2: "decoder_sequence"},)
                common_outputs[f"past_key_values.{i}.decoder.1"] = ({0: "batch", 2: "decoder_sequence"},)
                common_outputs[f"past_key_values.{i}.encoder.0"] = ({0: "batch", 2: "encoder_sequence"},)
                common_outputs[f"past_key_values.{i}.encoder.1"] = ({0: "batch", 2: "encoder_sequence"},)

        return common_outputs

    def generate_dummy_inputs(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        if self.use_past:
            raise NotImplementedError()

        # Generate encoder inputs
        encoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)

        # Generate decoder inputs
        decoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, 1, is_pair, framework)
        decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}

        return dict(**encoder_inputs, **decoder_inputs)
......@@ -28,7 +28,11 @@ from ...file_utils import (
_import_structure = {
    "configuration_xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"],
    "configuration_xlm_roberta": [
        "XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "XLMRobertaConfig",
        "XLMRobertaOnnxConfig",
    ],
}
if is_sentencepiece_available():
......@@ -62,7 +66,11 @@ if is_tf_available():
if TYPE_CHECKING:
    from .configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
    from .configuration_xlm_roberta import (
        XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
        XLMRobertaConfig,
        XLMRobertaOnnxConfig,
    )

    if is_sentencepiece_available():
        from .tokenization_xlm_roberta import XLMRobertaTokenizer
......